Skip to content

Commit

Permalink
BUG: Fix bug with sensor_colors
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Oct 4, 2023
1 parent a8b4638 commit 87c47b3
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 127 deletions.
2 changes: 2 additions & 0 deletions doc/changes/devel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ Enhancements
- Add the possibility to provide a float between 0 and 1 as ``n_grad``, ``n_mag`` and ``n_eeg`` in `~mne.compute_proj_raw`, `~mne.compute_proj_epochs` and `~mne.compute_proj_evoked` to select the number of vectors based on the cumulative explained variance (:gh:`11919` by `Mathieu Scheltienne`_)
- Added support for Artinis fNIRS data files to :func:`mne.io.read_raw_snirf` (:gh:`11926` by `Robert Luke`_)
- Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_)
- Add support for passing a :class:`python:dict` as ``sensor_color`` to specify per-channel-type colors in :func:`mne.viz.plot_alignment` (:gh:`12067` by `Eric Larson`)
- Add inferring EEGLAB files' montage unit automatically based on estimated head radius using :func:`read_raw_eeglab(..., montage_units="auto") <mne.io.read_raw_eeglab>` (:gh:`11925` by `Jack Zhang`_, :gh:`11951` by `Eric Larson`_)
- Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array <numpy.ndarray>` data (:gh:`11803` by `Alex Rockhill`_)
- Add support for writing forward solutions to HDF5 and convenience function :meth:`mne.Forward.save` (:gh:`12036` by `Eric Larson`_)
Expand All @@ -56,6 +57,7 @@ Bugs
- Fix bug with axis clip box boundaries in :func:`mne.viz.plot_evoked_topo` and related functions (:gh:`11999` by `Eric Larson`_)
- Fix bug with ``subject_info`` when loading data from and exporting to EDF file (:gh:`11952` by `Paul Roujansky`_)
- Fix bug with delayed checking of :class:`info["bads"] <mne.Info>` (:gh:`12038` by `Eric Larson`_)
- Fix bug with :func:`mne.viz.plot_alignment` where ``sensor_colors`` were not handled properly on a per-channel-type basis (:gh:`12067` by `Eric Larson`)
- Fix handling of channel information in annotations when loading data from and exporting to EDF file (:gh:`11960` :gh:`12017` :gh:`12044` by `Paul Roujansky`_)
- Add missing ``overwrite`` and ``verbose`` parameters to :meth:`Transform.save() <mne.transforms.Transform.save>` (:gh:`12004` by `Marijn van Vliet`_)
- Fix parsing of eye-link :class:`~mne.Annotations` when ``apply_offsets=False`` is provided to :func:`~mne.io.read_raw_eyelink` (:gh:`12003` by `Mathieu Scheltienne`_)
Expand Down
3 changes: 3 additions & 0 deletions mne/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@
coreg=dict(
mri_fid_opacity=1.0,
dig_fid_opacity=1.0,
# go from unit scaling (e.g., unit-radius sphere) to meters
mri_fid_scale=5e-3,
dig_fid_scale=8e-3,
extra_scale=4e-3,
Expand All @@ -235,6 +236,8 @@
eegp_height=0.1,
ecog_scale=5e-3,
seeg_scale=5e-3,
meg_scale=1.0, # sensors are already in SI units
ref_meg_scale=1.0,
dbs_scale=5e-3,
fnirs_scale=5e-3,
source_scale=5e-3,
Expand Down
7 changes: 3 additions & 4 deletions mne/gui/_coreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ def _redraw(self, *, verbose=None):
mri_fids=self._add_mri_fiducials,
hsp=self._add_head_shape_points,
hpi=self._add_hpi_coils,
eeg=self._add_eeg_channels,
eeg=self._add_eeg_fnirs_channels,
head_fids=self._add_head_fiducials,
helmet=self._add_helmet,
)
Expand Down Expand Up @@ -1217,7 +1217,7 @@ def _add_head_shape_points(self):
hsp_actors = None
self._update_actor("head_shape_points", hsp_actors)

def _add_eeg_channels(self):
def _add_eeg_fnirs_channels(self):
if self._eeg_channels:
eeg = ["original"]
picks = pick_types(self._info, eeg=(len(eeg) > 0), fnirs=True)
Expand All @@ -1240,8 +1240,7 @@ def _add_eeg_channels(self):
check_inside=self._check_inside,
nearest=self._nearest,
)
sens_actors = actors["eeg"]
sens_actors.extend(actors["fnirs"])
sens_actors = sum(actors.values(), list())
else:
sens_actors = None
else:
Expand Down
216 changes: 110 additions & 106 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#
# License: Simplified BSD

from collections import defaultdict
import os
import os.path as op
import warnings
Expand Down Expand Up @@ -604,11 +605,20 @@ def plot_alignment(
.. versionadded:: 0.16
.. versionchanged:: 1.0
Defaults to ``'terrain'``.
sensor_colors : array-like | None
Colors to use for the sensor glyphs. Can be list-like of color strings
(length ``n_sensors``) or array-like of RGB(A) values (shape
``(n_sensors, 3)`` or ``(n_sensors, 4)``). ``None`` (the default) uses
the default sensor colors for the :func:`~mne.viz.plot_alignment` GUI.
sensor_colors : array-like of color | dict | None
Colors to use for the sensor glyphs. Can be None (default) to use default
colors. A dict should provide the channel type to color mapping, e.g.::
dict(eeg=eeg_colors)
Where the value (e.g., ``eeg_colors`` above) can be broadcast to a matplotlib
array of colors with length that matches the number of channels of
that type, i.e., is compatible with :func:`matplotlib.colors.to_rgba_array`.
For example, this could be the string ``"k"``, a list of ``n_eeg`` color
strings, or an NumPy ndarray of shape ``(n_eeg, 3)`` or ``(n_eeg, 4)``.
.. versionchanged:: 1.6
Support for passing a ``dict`` was added.
%(verbose)s
Returns
Expand Down Expand Up @@ -1437,29 +1447,16 @@ def _plot_sensors(
sensor_colors=None,
):
"""Render sensors in a 3D scene."""
from matplotlib.colors import to_rgba_array

defaults = DEFAULTS["coreg"]
ch_pos, sources, detectors = _ch_pos_in_coord_frame(
pick_info(info, picks), to_cf_t=to_cf_t, warn_meg=warn_meg
)

actors = dict(
meg=list(),
ref_meg=list(),
eeg=list(),
fnirs=list(),
ecog=list(),
seeg=list(),
dbs=list(),
)
locs = dict(
eeg=list(),
fnirs=list(),
ecog=list(),
seeg=list(),
source=list(),
detector=list(),
)
scalar = 1 if units == "m" else 1e3
actors = defaultdict(lambda: list())
locs = defaultdict(lambda: list())
unit_scalar = 1 if units == "m" else 1e3
for ch_name, ch_coord in ch_pos.items():
ch_type = channel_type(info, info.ch_names.index(ch_name))
# for default picking
Expand All @@ -1471,119 +1468,126 @@ def _plot_sensors(
plot_sensors = (ch_type != "fnirs" or "channels" in fnirs) and (
ch_type != "eeg" or "original" in eeg
)
color = defaults[ch_type + "_color"]
# plot sensors
if isinstance(ch_coord, tuple): # is meg, plot coil
verts, triangles = ch_coord
actor, _ = renderer.surface(
surface=dict(rr=verts * scalar, tris=triangles),
color=color,
opacity=0.25,
backface_culling=True,
)
actors[ch_type].append(actor)
else:
if plot_sensors:
locs[ch_type].append(ch_coord)
ch_coord = dict(rr=ch_coord[0] * unit_scalar, tris=ch_coord[1])
if plot_sensors:
locs[ch_type].append(ch_coord)
if ch_name in sources and "sources" in fnirs:
locs["source"].append(sources[ch_name])
if ch_name in detectors and "detectors" in fnirs:
locs["detector"].append(detectors[ch_name])
# Plot these now
if ch_name in sources and ch_name in detectors and "pairs" in fnirs:
actor, _ = renderer.tube( # array of origin and dest points
origin=sources[ch_name][np.newaxis] * scalar,
destination=detectors[ch_name][np.newaxis] * scalar,
radius=0.001 * scalar,
origin=sources[ch_name][np.newaxis] * unit_scalar,
destination=detectors[ch_name][np.newaxis] * unit_scalar,
radius=0.001 * unit_scalar,
)
actors[ch_type].append(actor)
del ch_type

# add sensors
for sensor_type in locs.keys():
if len(locs[sensor_type]) > 0:
sens_loc = np.array(locs[sensor_type])
sens_loc = sens_loc[~np.isnan(sens_loc).any(axis=1)]
scale = defaults[sensor_type + "_scale"]
if sensor_colors is None:
color = defaults[sensor_type + "_color"]
# now actually plot the sensors
extra = ""
types = (dict, None)
if len(locs) == 0:
return
elif len(locs) == 1:
# Upsample from array-like to dict when there is one channel type
extra = "(or array-like since only one sensor type is plotted)"
if sensor_colors is not None and not isinstance(sensor_colors, dict):
sensor_colors = {
list(locs)[0]: to_rgba_array(sensor_colors),
}
else:
extra = f"when more than one channel type ({list(locs)}) is plotted"
_validate_type(sensor_colors, types, "sensor_colors", extra=extra)
del extra, types
if sensor_colors is None:
sensor_colors = dict()
assert isinstance(sensor_colors, dict)
for ch_type, sens_loc in locs.items():
assert len(sens_loc) # should be guaranteed above
colors = to_rgba_array(sensor_colors.get(ch_type, defaults[ch_type + "_color"]))
_check_option(
f"len(sensor_colors[{repr(ch_type)}])",
colors.shape[0],
(len(sens_loc), 1),
)
scale = defaults[ch_type + "_scale"] * unit_scalar
if isinstance(sens_loc[0], dict): # meg coil
if len(colors) == 1:
colors = [colors[0]] * len(sens_loc)
for surface, color in zip(sens_loc, colors):
actor, _ = renderer.surface(
surface=surface,
color=color[:3],
opacity=0.25 * color[3],
backface_culling=False, # visible from all sides
)
actors[ch_type].append(actor)
else:
sens_loc = np.array(sens_loc, float)
mask = ~np.isnan(sens_loc).any(axis=1)
if len(colors) == 1:
# Single color mode (one actor)
actor, _ = _plot_glyphs(
renderer=renderer,
loc=sens_loc * scalar,
color=color,
scale=scale * scalar,
opacity=sensor_opacity,
loc=sens_loc[mask] * unit_scalar,
color=colors[0, :3],
scale=scale,
opacity=sensor_opacity * colors[0, 3],
orient_glyphs=orient_glyphs,
scale_by_distance=scale_by_distance,
project_points=project_points,
surf=surf,
check_inside=check_inside,
nearest=nearest,
)
if sensor_type in ("source", "detector"):
sensor_type = "fnirs"
actors[sensor_type].append(actor)
actors[ch_type].append(actor)
else:
actor_list = []
for idx_sen in range(sens_loc.shape[0]):
sensor_colors = np.asarray(sensor_colors)
if (
sensor_colors.ndim not in (1, 2)
or sensor_colors.shape[0] != sens_loc.shape[0]
):
raise ValueError(
"sensor_colors should either be None or be "
"array-like with shape (n_sensors,) or "
"(n_sensors, 3) or (n_sensors, 4). Got shape "
f"{sensor_colors.shape}."
)
color = sensor_colors[idx_sen]

# Multi-color mode (multiple actors)
for loc, color, usable in zip(sens_loc, colors, mask):
if not usable:
continue
actor, _ = _plot_glyphs(
renderer=renderer,
loc=(sens_loc * scalar)[idx_sen, :],
color=color,
scale=scale * scalar,
opacity=sensor_opacity,
loc=loc * unit_scalar,
color=color[:3],
scale=scale,
opacity=sensor_opacity * color[3],
orient_glyphs=orient_glyphs,
scale_by_distance=scale_by_distance,
project_points=project_points,
surf=surf,
check_inside=check_inside,
nearest=nearest,
)
actor_list.append(actor)
if sensor_type in ("source", "detector"):
sensor_type = "fnirs"
actors[sensor_type].append(actor_list)

# add projected eeg
eeg_indices = pick_types(info, eeg=True)
if eeg_indices.size > 0 and "projected" in eeg:
logger.info("Projecting sensors to the head surface")
eeg_loc = np.array([ch_pos[info.ch_names[idx]] for idx in eeg_indices])
eeg_loc = eeg_loc[~np.isnan(eeg_loc).any(axis=1)]
eegp_loc, eegp_nn = _project_onto_surface(
eeg_loc, head_surf, project_rrs=True, return_nn=True
)[2:4]
del eeg_loc
eegp_loc *= scalar
scale = defaults["eegp_scale"] * scalar
actor, _ = renderer.quiver3d(
x=eegp_loc[:, 0],
y=eegp_loc[:, 1],
z=eegp_loc[:, 2],
u=eegp_nn[:, 0],
v=eegp_nn[:, 1],
w=eegp_nn[:, 2],
color=defaults["eegp_color"],
mode="cylinder",
scale=scale,
opacity=0.6,
glyph_height=defaults["eegp_height"],
glyph_center=(0.0, -defaults["eegp_height"] / 2.0, 0),
glyph_resolution=20,
backface_culling=True,
)
actors["eeg"].append(actor)
actors[ch_type].append(actor)
if ch_type == "eeg" and "projected" in eeg:
logger.info("Projecting sensors to the head surface")
eegp_loc, eegp_nn = _project_onto_surface(
sens_loc[mask], head_surf, project_rrs=True, return_nn=True
)[2:4]
eegp_loc *= unit_scalar
actor, _ = renderer.quiver3d(
x=eegp_loc[:, 0],
y=eegp_loc[:, 1],
z=eegp_loc[:, 2],
u=eegp_nn[:, 0],
v=eegp_nn[:, 1],
w=eegp_nn[:, 2],
color=defaults["eegp_color"],
mode="cylinder",
scale=defaults["eegp_scale"] * unit_scalar,
opacity=0.6,
glyph_height=defaults["eegp_height"],
glyph_center=(0.0, -defaults["eegp_height"] / 2.0, 0),
glyph_resolution=20,
backface_culling=True,
)
actors["eeg"].append(actor)
actors = dict(actors) # get rid of defaultdict

return actors

Expand Down
23 changes: 11 additions & 12 deletions mne/viz/backends/_pyvista.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,17 +373,18 @@ def polydata(
polygon_offset=None,
**kwargs,
):
from matplotlib.colors import to_rgba_array

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
rgba = False
if color is not None and len(color) == mesh.n_points:
if color.shape[1] == 3:
scalars = np.c_[color, np.ones(mesh.n_points)]
else:
scalars = color
scalars = (scalars * 255).astype("ubyte")
color = None
rgba = True
if color is not None:
# See if we need to convert or not
check_color = to_rgba_array(color)
if len(check_color) == mesh.n_points:
scalars = (check_color * 255).astype("ubyte")
color = None
rgba = True
if isinstance(colormap, np.ndarray):
if colormap.dtype == np.uint8:
colormap = colormap.astype(np.float64) / 255.0
Expand All @@ -395,24 +396,22 @@ def polydata(
mesh.GetPointData().SetActiveNormals("Normals")
else:
_compute_normals(mesh)
if "rgba" in kwargs:
rgba = kwargs["rgba"]
kwargs.pop("rgba")
smooth_shading = self.smooth_shading
if representation == "wireframe":
smooth_shading = False # never use smooth shading for wf
rgba = kwargs.pop("rgba", rgba)
actor = _add_mesh(
plotter=self.plotter,
mesh=mesh,
color=color,
scalars=scalars,
edge_color=color,
rgba=rgba,
opacity=opacity,
cmap=colormap,
backface_culling=backface_culling,
rng=[vmin, vmax],
show_scalar_bar=False,
rgba=rgba,
smooth_shading=smooth_shading,
interpolate_before_map=interpolate_before_map,
style=representation,
Expand Down
Loading

0 comments on commit 87c47b3

Please sign in to comment.