Skip to content

Commit

Permalink
ENH: Add MEG overlay to coreg
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Oct 10, 2023
1 parent 3b6a339 commit 60aefda
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 42 deletions.
132 changes: 94 additions & 38 deletions mne/gui/_coreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ class CoregistrationUI(HasTraits):
If True, display the head shape points. Defaults to True.
eeg_channels : bool
If True, display the EEG channels. Defaults to True.
meg_channels : bool
If True, display the MEG channels. Defaults to False.
fnirs_channels : bool
If True, display the MEG channels. Defaults to True.
orient_glyphs : bool
If True, orient the sensors towards the head surface. Default to False.
scale_by_distance : bool
Expand Down Expand Up @@ -154,6 +158,8 @@ class CoregistrationUI(HasTraits):
_hpi_coils = Bool()
_head_shape_points = Bool()
_eeg_channels = Bool()
_meg_channels = Bool()
_fnirs_channels = Bool()
_head_resolution = Bool()
_head_opacity = Float()
_helmet = Bool()
Expand All @@ -178,6 +184,8 @@ def __init__(
hpi_coils=None,
head_shape_points=None,
eeg_channels=None,
meg_channels=None,
fnirs_channels=None,
orient_glyphs=None,
scale_by_distance=None,
mark_inside=None,
Expand Down Expand Up @@ -232,6 +240,8 @@ def _get_default(var, val):
hpi_coils=_get_default(hpi_coils, True),
head_shape_points=_get_default(head_shape_points, True),
eeg_channels=_get_default(eeg_channels, True),
meg_channels=_get_default(meg_channels, False),
fnirs_channels=_get_default(fnirs_channels, False),
head_resolution=_get_default(head_resolution, True),
head_opacity=_get_default(head_opacity, 0.8),
helmet=False,
Expand Down Expand Up @@ -304,6 +314,8 @@ def _get_default(var, val):
self._set_hpi_coils(self._defaults["hpi_coils"])
self._set_head_shape_points(self._defaults["head_shape_points"])
self._set_eeg_channels(self._defaults["eeg_channels"])
self._set_meg_channels(self._defaults["meg_channels"])
self._set_fnirs_channels(self._defaults["fnirs_channels"])
self._set_head_resolution(self._defaults["head_resolution"])
self._set_helmet(self._defaults["helmet"])
self._set_grow_hair(self._defaults["grow_hair"])
Expand Down Expand Up @@ -352,7 +364,7 @@ def _get_default(var, val):
True: dict(azimuth=90, elevation=90), # front
False: dict(azimuth=180, elevation=90),
} # left
self._renderer.set_camera(distance=None, **views[self._lock_fids])
self._renderer.set_camera(distance="auto", **views[self._lock_fids])
self._redraw()
# XXX: internal plotter/renderer should not be exposed
if not self._immediate_redraw:
Expand Down Expand Up @@ -483,6 +495,12 @@ def _set_head_shape_points(self, state):
def _set_eeg_channels(self, state):
self._eeg_channels = bool(state)

def _set_meg_channels(self, state):
self._meg_channels = bool(state)

def _set_fnirs_channels(self, state):
self._fnirs_channels = bool(state)

def _set_head_resolution(self, state):
self._head_resolution = bool(state)

Expand Down Expand Up @@ -568,6 +586,8 @@ def _set_point_weight(self, weight, point):
"hpi": "_set_hpi_coils",
"hsp": "_set_head_shape_points",
"eeg": "_set_eeg_channels",
"meg": "_set_meg_channels",
"fnirs": "_set_fnirs_channels",
}
if point in funcs.keys():
getattr(self, funcs[point])(weight > 0)
Expand Down Expand Up @@ -612,6 +632,7 @@ def _lock_fids_changed(self, change=None):
"save_mri_fids",
# View options
"helmet",
"meg",
"head_opacity",
"high_res_head",
# Digitization source
Expand Down Expand Up @@ -705,11 +726,11 @@ def _info_file_changed(self, change=None):

@observe("_orient_glyphs")
def _orient_glyphs_changed(self, change=None):
self._update_plot(["hpi", "hsp", "eeg"])
self._update_plot(["hpi", "hsp", "sensors"])

@observe("_scale_by_distance")
def _scale_by_distance_changed(self, change=None):
self._update_plot(["hpi", "hsp", "eeg"])
self._update_plot(["hpi", "hsp", "sensors"])

@observe("_mark_inside")
def _mark_inside_changed(self, change=None):
Expand All @@ -725,7 +746,15 @@ def _head_shape_point_changed(self, change=None):

@observe("_eeg_channels")
def _eeg_channels_changed(self, change=None):
self._update_plot("eeg")
self._update_plot("sensors")

@observe("_meg_channels")
def _meg_channels_changed(self, change=None):
self._update_plot("sensors")

@observe("_fnirs_channels")
def _fnirs_channels_changed(self, change=None):
self._update_plot("sensors")

@observe("_head_resolution")
def _head_resolution_changed(self, change=None):
Expand Down Expand Up @@ -826,6 +855,7 @@ def _configure_legend(self):
mri_fids_legend_actor = self._renderer.legend(labels=labels)
self._update_actor("mri_fids_legend", mri_fids_legend_actor)

@safe_event
@verbose
def _redraw(self, *, verbose=None):
if not self._redraws_pending:
Expand All @@ -835,7 +865,7 @@ def _redraw(self, *, verbose=None):
mri_fids=self._add_mri_fiducials,
hsp=self._add_head_shape_points,
hpi=self._add_hpi_coils,
eeg=self._add_eeg_fnirs_channels,
sensors=self._add_channels,
head_fids=self._add_head_fiducials,
helmet=self._add_helmet,
)
Expand Down Expand Up @@ -958,7 +988,7 @@ def _update_plot(self, changes="all", verbose=None):
"mri_fids", # MRI first
"hsp",
"hpi",
"eeg",
"sensors",
"head_fids", # then dig
"helmet",
)
Expand Down Expand Up @@ -1042,7 +1072,7 @@ def _follow_fiducial_view(self):
kwargs = dict(front=(90.0, 90.0), left=(180, 90), right=(0.0, 90))
kwargs = dict(zip(("azimuth", "elevation"), kwargs[view[fid]]))
if not self._lock_fids:
self._renderer.set_camera(distance=None, **kwargs)
self._renderer.set_camera(distance="auto", **kwargs)

def _update_fiducials(self):
fid = self._current_fiducial
Expand Down Expand Up @@ -1146,7 +1176,13 @@ def _forward_widget_command(
return ret

def _set_sensors_visibility(self, state):
sensors = ["head_fiducials", "hpi_coils", "head_shape_points", "eeg_channels"]
sensors = [
"head_fiducials",
"hpi_coils",
"head_shape_points",
"sensors",
"helmet",
]
for sensor in sensors:
if sensor in self._actors and self._actors[sensor] is not None:
actors = self._actors[sensor]
Expand All @@ -1157,7 +1193,12 @@ def _set_sensors_visibility(self, state):

def _update_actor(self, actor_name, actor):
# XXX: internal plotter/renderer should not be exposed
self._renderer.plotter.remove_actor(self._actors.get(actor_name), render=False)
# Work around PyVista sequential update bug with iterable until > 0.42.3 is req
actors = self._actors.get(actor_name) or []
if not isinstance(actors, list):
actors = [actors]
for this_actor in actors:
self._renderer.plotter.remove_actor(this_actor, render=False)
self._actors[actor_name] = actor

def _add_mri_fiducials(self):
Expand Down Expand Up @@ -1217,35 +1258,43 @@ def _add_head_shape_points(self):
hsp_actors = None
self._update_actor("head_shape_points", hsp_actors)

def _add_eeg_fnirs_channels(self):
def _add_channels(self):
plot_types = dict(eeg=False, meg=False, fnirs=False)
if self._eeg_channels:
eeg = ["original"]
picks = pick_types(self._info, eeg=(len(eeg) > 0), fnirs=True)
if len(picks) > 0:
actors = _plot_sensors(
self._renderer,
self._info,
self._to_cf_t,
picks,
meg=False,
eeg=eeg,
fnirs=["sources", "detectors"],
warn_meg=False,
head_surf=self._head_geo,
units="m",
sensor_opacity=self._defaults["sensor_opacity"],
orient_glyphs=self._orient_glyphs,
scale_by_distance=self._scale_by_distance,
surf=self._head_geo,
check_inside=self._check_inside,
nearest=self._nearest,
)
sens_actors = sum(actors.values(), list())
else:
sens_actors = None
else:
sens_actors = None
self._update_actor("eeg_channels", sens_actors)
plot_types["eeg"] = ["original"]
if self._meg_channels:
plot_types["meg"] = ["sensors"]
if self._fnirs_channels:
plot_types["fnirs"] = ["sources", "detectors"]
sens_actors = list()
# until opacity can be specified using a dict, we need to iterate
sensor_opacity = dict(
eeg=self._defaults["sensor_opacity"],
fnirs=self._defaults["sensor_opacity"],
meg=0.25,
)
for ch_type, plot_type in plot_types.items():
picks = pick_types(self._info, ref_meg=False, **{ch_type: True})
if not (len(picks) and plot_type):
continue
these_actors = _plot_sensors(
self._renderer,
self._info,
self._to_cf_t,
picks=picks,
warn_meg=False,
head_surf=self._head_geo,
units="m",
sensor_opacity=sensor_opacity[ch_type],
orient_glyphs=self._orient_glyphs,
scale_by_distance=self._scale_by_distance,
surf=self._head_geo,
check_inside=self._check_inside,
nearest=self._nearest,
**plot_types,
)
sens_actors.extend(sum(these_actors.values(), list()))
self._update_actor("sensors", sens_actors)

def _add_head_surface(self):
bem = None
Expand Down Expand Up @@ -1336,7 +1385,7 @@ def _fits_icp(self):
def _fit_icp_real(self, *, update_head):
with self._lock(params=True, fitting=True):
self._current_icp_iterations = 0
updates = ["hsp", "hpi", "eeg", "head_fids", "helmet"]
updates = ["hsp", "hpi", "sensors", "head_fids", "helmet"]
if update_head:
updates.insert(0, "head")

Expand Down Expand Up @@ -1707,6 +1756,13 @@ def _configure_dock(self):
tooltip="Enable/Disable MEG helmet",
layout=view_options_layout,
)
self._widgets["meg"] = self._renderer._dock_add_check_box(
name="Show MEG sensors",
value=self._helmet,
callback=self._set_meg_channels,
tooltip="Enable/Disable MEG sensors",
layout=view_options_layout,
)
self._widgets["high_res_head"] = self._renderer._dock_add_check_box(
name="Show high-resolution head",
value=self._head_resolution,
Expand Down
10 changes: 6 additions & 4 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,7 @@ def _sensor_shape(coil):
except ImportError: # scipy < 1.8
from scipy.spatial.qhull import QhullError
id_ = coil["type"] & 0xFFFF
pad = True
z_value = 0
# Square figure eight
if id_ in (
FIFF.FIFFV_COIL_NM_122,
Expand All @@ -1623,6 +1623,8 @@ def _sensor_shape(coil):
tris = np.concatenate(
(_make_tris_fan(4), _make_tris_fan(4)[:, ::-1] + 4), axis=0
)
# Offset for visibility (using heuristic for sanely named Neuromag coils)
z_value = 0.001 * (1 + coil["chname"].endswith("2"))
# Square
elif id_ in (
FIFF.FIFFV_COIL_POINT_MAGNETOMETER,
Expand Down Expand Up @@ -1693,11 +1695,11 @@ def _sensor_shape(coil):
rr_rot = rrs @ u
tris = Delaunay(rr_rot[:, :2]).simplices
tris = np.concatenate((tris, tris[:, ::-1]))
pad = False
z_value = None

# Go from (x,y) -> (x,y,z)
if pad:
rrs = np.pad(rrs, ((0, 0), (0, 1)), mode="constant")
if z_value is not None:
rrs = np.pad(rrs, ((0, 0), (0, 1)), mode="constant", constant_values=z_value)
assert rrs.ndim == 2 and rrs.shape[1] == 3
return rrs, tris

Expand Down

0 comments on commit 60aefda

Please sign in to comment.