Skip to content

Commit

Permalink
MRG: Refactor plot_alignment (#9714)
Browse files Browse the repository at this point in the history
* start with _plot_head_surface

* refactor

* add _plot_axes

* update

* fix

* touch example

* fix

* update

* return actor

* add _plot_forward

* add _plot_mri_fiducials

* return actors
  • Loading branch information
GuillaumeFavelier authored Sep 6, 2021
1 parent b20f6b3 commit 3e7914e
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 115 deletions.
274 changes: 161 additions & 113 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,6 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
.. versionadded:: 0.15
"""
from ..forward import Forward
from ..coreg import get_mni_fiducials
# Update the backend
from .backends.renderer import _get_renderer

Expand Down Expand Up @@ -572,14 +570,6 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
if src_subject is not None and subject != src_subject:
raise ValueError(f'subject ("{subject}") did not match the '
f'subject name in src ("{src_subject}")')
if fwd is not None:
_validate_type(fwd, [Forward])
fwd_rr = fwd['source_rr']
if fwd['source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI:
fwd_nn = fwd['source_nn'].reshape(-1, 1, 3)
else:
fwd_nn = fwd['source_nn'].reshape(-1, 3, 3)

# configure transforms
if trans == 'auto':
subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
Expand Down Expand Up @@ -645,23 +635,18 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
surfs[hemi] = _read_mri_surface(brain_fname)

# Head surface:
head = [s for s in surfaces if s in
('auto', 'head', 'outer_skin', 'head-dense', 'seghead')]
head_keys = ('auto', 'head', 'outer_skin', 'head-dense', 'seghead')
head = [s for s in surfaces if s in head_keys]
if len(head) > 1:
raise ValueError('Can only supply one head-like surface name, '
f'got {head}')
head = head[0] if head else False
if head is not False:
surfaces.pop(surfaces.index(head))
head_surf = _get_head_surface(head, subject, subjects_dir, bem=bem)
if head_surf is not None:
surfs['head'] = head_surf
elif 'projected' in eeg:
raise ValueError('A head surface is required to project EEG, '
'"head", "outer_skin", "head-dense" or "seghead" '
'must be in surfaces or surfaces must be "auto"')
else:
head_surf = None

# Skull surface:
skulls = [s for s in surfaces if s in ('outer_skull', 'inner_skull')]
Expand Down Expand Up @@ -693,9 +678,8 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
head_alpha = max_alpha
else:
head_alpha = alpha_range[0]
alphas = dict(head=head_alpha, helmet=0.25, lh=hemi_val, rh=hemi_val)
colors = dict(head=DEFAULTS['coreg']['head_color'],
helmet=DEFAULTS['coreg']['helmet_color'],
alphas = dict(helmet=0.25, lh=hemi_val, rh=hemi_val)
colors = dict(helmet=DEFAULTS['coreg']['helmet_color'],
lh=(0.5,) * 3, rh=(0.5,) * 3)
for idx, name in enumerate(skulls):
alphas[name] = alpha_range[idx + 1]
Expand All @@ -706,6 +690,8 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
for k, v in user_alpha.items():
if v is not None:
alphas[k] = v
if k in head_keys and v is not None:
head_alpha = v
fid_colors = tuple(
defaults[f'{key}_color'] for key in ('lpa', 'nasion', 'rpa'))

Expand All @@ -715,101 +701,35 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
if interaction == 'terrain':
renderer.set_interaction('terrain')

# plot head
_, head_surf = _plot_head_surface(
renderer, head, subject, subjects_dir, bem, coord_frame,
to_cf_t, alpha=head_alpha)

# plot surfaces
for key, surf in surfs.items():
# Surfs can sometimes be in head coords (e.g., if coming from sphere)
assert isinstance(surf, dict), f'{key}: {type(surf)}'
surf = transform_surface_to(
surf, coord_frame, [to_cf_t['mri'], to_cf_t['head']], copy=True)
renderer.surface(surface=surf, color=colors[key],
opacity=alphas[key],
backface_culling=(key != 'helmet'))
if brain and 'lh' not in surfs: # one layer sphere
assert bem['coord_frame'] == FIFF.FIFFV_COORD_HEAD
center = bem['r0'].copy()
center = apply_trans(to_cf_t['head'], center)
renderer.sphere(center, scale=0.01, color=colors['lh'],
opacity=alphas['lh'])
if show_axes:
axes = [(to_cf_t['head'], (0.9, 0.3, 0.3))] # always show head
if not np.allclose(head_mri_t['trans'], np.eye(4)): # Show MRI
axes.append((to_cf_t['mri'], (0.6, 0.6, 0.6)))
if pick_types(info, meg=True).size > 0: # Show MEG
axes.append((to_cf_t['meg'], (0., 0.6, 0.6)))
for ax in axes:
x, y, z = np.tile(ax[0]['trans'][:3, 3], 3).reshape((3, 3)).T
u, v, w = ax[0]['trans'][:3, :3]
renderer.sphere(center=np.column_stack((x[0], y[0], z[0])),
color=ax[1], scale=3e-3)
renderer.quiver3d(x=x, y=y, z=z, u=u, v=v, w=w, mode='arrow',
scale=2e-2, color=ax[1],
scale_mode='scalar', resolution=20,
scalars=[0.33, 0.66, 1.0])
_plot_axes(renderer, info, to_cf_t, head_mri_t)

# plot points
_check_option('dig', dig, (True, False, 'fiducials'))
if dig:
no_ext_or_hpi = True
if dig is True:
hpi_loc = np.array([
d['r'] for d in (info['dig'] or [])
if (d['kind'] == FIFF.FIFFV_POINT_HPI and
d['coord_frame'] == FIFF.FIFFV_COORD_HEAD)])
hpi_loc = apply_trans(to_cf_t['head'], hpi_loc)
renderer.sphere(center=hpi_loc, color=defaults['hpi_color'],
scale=defaults['hpi_scale'], opacity=0.5,
backface_culling=True)
ext_loc = np.array([
d['r'] for d in (info['dig'] or [])
if (d['kind'] == FIFF.FIFFV_POINT_EXTRA and
d['coord_frame'] == FIFF.FIFFV_COORD_HEAD)])
ext_loc = apply_trans(to_cf_t['head'], ext_loc)
renderer.sphere(center=ext_loc, color=defaults['extra_color'],
scale=defaults['extra_scale'], opacity=0.25,
backface_culling=True)
no_ext_or_hpi = len(hpi_loc) + len(ext_loc) == 0
car_loc = _fiducial_coords(info['dig'], FIFF.FIFFV_COORD_HEAD)
car_loc = apply_trans(to_cf_t['head'], car_loc)
if len(car_loc) == 0 and no_ext_or_hpi:
warn('Digitization points not found. Cannot plot digitization.')
for color, data in zip(fid_colors, car_loc):
renderer.sphere(center=data, color=color,
scale=defaults['dig_fid_scale'],
opacity=defaults['dig_fid_opacity'],
backface_culling=True)
_plot_head_shape_points(renderer, info, to_cf_t)
_plot_head_fiducials(renderer, info, to_cf_t, fid_colors)

if mri_fiducials:
if mri_fiducials is True:
subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
if subject is None:
raise ValueError("Subject needs to be specified to "
"automatically find the fiducials file.")
mri_fiducials = op.join(subjects_dir, subject, 'bem',
subject + '-fiducials.fif')
if isinstance(mri_fiducials, str):
if mri_fiducials == 'estimated':
mri_fiducials = get_mni_fiducials(subject, subjects_dir)
else:
mri_fiducials, cf = read_fiducials(mri_fiducials)
if cf != FIFF.FIFFV_COORD_MRI:
raise ValueError("Fiducials are not in MRI space")
fid_loc = _fiducial_coords(mri_fiducials, FIFF.FIFFV_COORD_MRI)
fid_loc = apply_trans(to_cf_t['mri'], fid_loc)
transform = np.eye(4)
transform[:3, :3] = to_cf_t['mri']['trans'][:3, :3] * \
defaults['mri_fid_scale']
# rotate around Z axis 45 deg first
transform = transform @ rotation(0, 0, np.pi / 4)
for color, data in zip(fid_colors, fid_loc):
renderer.quiver3d(
x=data[0], y=data[1], z=data[2],
u=1., v=0., w=0., color=color, mode='oct',
scale=1., opacity=defaults['mri_fid_opacity'],
backface_culling=True, solid_transform=transform)
_plot_mri_fiducials(renderer, mri_fiducials, subjects_dir, subject,
to_cf_t, fid_colors)

# plot sensors
if picks.size > 0:
_plot_sensors(info, to_cf_t, renderer, picks, meg, eeg, fnirs,
_plot_sensors(renderer, info, to_cf_t, picks, meg, eeg, fnirs,
warn_meg, head_surf, 'm')

if src is not None:
Expand Down Expand Up @@ -846,21 +766,7 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
backface_culling=True)

if fwd is not None:
# update coordinate frame
fwd_trans = to_cf_t[_frame_to_str[fwd['coord_frame']]]
fwd_rr = apply_trans(fwd_trans, fwd_rr)
fwd_nn = apply_trans(fwd_trans, fwd_nn, move=False)
red = (1.0, 0.0, 0.0)
green = (0.0, 1.0, 0.0)
blue = (0.0, 0.0, 1.0)
for ori, color in zip(range(fwd_nn.shape[1]), (red, green, blue)):
renderer.quiver3d(fwd_rr[:, 0],
fwd_rr[:, 1],
fwd_rr[:, 2],
fwd_nn[:, ori, 0],
fwd_nn[:, ori, 1],
fwd_nn[:, ori, 2],
color=color, mode='arrow', scale=1.5e-3)
_plot_forward(renderer, fwd, to_cf_t)

renderer.set_camera(azimuth=90, elevation=90,
distance=0.6, focalpoint=(0., 0., 0.))
Expand Down Expand Up @@ -964,7 +870,149 @@ def _ch_pos_in_coord_frame(info, to_cf_t, warn_meg=True, verbose=None):
return chs['ch_pos'], chs['sources'], chs['detectors']


def _plot_sensors(info, to_cf_t, renderer, picks, meg, eeg, fnirs,
def _plot_head_surface(renderer, head, subject, subjects_dir, bem,
coord_frame, to_cf_t, alpha, color=None):
"""Render a head surface in a 3D scene."""
color = DEFAULTS['coreg']['head_color'] if color is None else color
actor = None
surf = None
if head is not False:
surf = _get_head_surface(head, subject, subjects_dir, bem=bem)
surf = transform_surface_to(
surf, coord_frame, [to_cf_t['mri'], to_cf_t['head']],
copy=True)
actor, _ = renderer.surface(
surface=surf, color=color, opacity=alpha,
backface_culling=False)
return actor, surf


def _plot_axes(renderer, info, to_cf_t, head_mri_t):
"""Render different axes a 3D scene."""
axes = [(to_cf_t['head'], (0.9, 0.3, 0.3))] # always show head
if not np.allclose(head_mri_t['trans'], np.eye(4)): # Show MRI
axes.append((to_cf_t['mri'], (0.6, 0.6, 0.6)))
if pick_types(info, meg=True).size > 0: # Show MEG
axes.append((to_cf_t['meg'], (0., 0.6, 0.6)))
actors = list()
for ax in axes:
x, y, z = np.tile(ax[0]['trans'][:3, 3], 3).reshape((3, 3)).T
u, v, w = ax[0]['trans'][:3, :3]
actor, _ = renderer.sphere(center=np.column_stack((x[0], y[0], z[0])),
color=ax[1], scale=3e-3)
actors.append(actor)
actor, _ = renderer.quiver3d(x=x, y=y, z=z, u=u, v=v, w=w,
mode='arrow', scale=2e-2, color=ax[1],
scale_mode='scalar', resolution=20,
scalars=[0.33, 0.66, 1.0])
actors.append(actor)
return actors


def _plot_head_fiducials(renderer, info, to_cf_t, fid_colors):
defaults = DEFAULTS['coreg']
car_loc = _fiducial_coords(info['dig'], FIFF.FIFFV_COORD_HEAD)
car_loc = apply_trans(to_cf_t['head'], car_loc)
if len(car_loc) == 0:
warn('Digitization points not found. Cannot plot digitization.')
actors = list()
for color, data in zip(fid_colors, car_loc):
actor, _ = renderer.sphere(center=data, color=color,
scale=defaults['dig_fid_scale'],
opacity=defaults['dig_fid_opacity'],
backface_culling=True)
actors.append(actor)
return actors


def _plot_mri_fiducials(renderer, mri_fiducials, subjects_dir, subject,
to_cf_t, fid_colors):
from ..coreg import get_mni_fiducials
defaults = DEFAULTS['coreg']
if mri_fiducials is True:
subjects_dir = get_subjects_dir(subjects_dir, raise_error=True)
if subject is None:
raise ValueError("Subject needs to be specified to "
"automatically find the fiducials file.")
mri_fiducials = op.join(subjects_dir, subject, 'bem',
subject + '-fiducials.fif')
if isinstance(mri_fiducials, str):
if mri_fiducials == 'estimated':
mri_fiducials = get_mni_fiducials(subject, subjects_dir)
else:
mri_fiducials, cf = read_fiducials(mri_fiducials)
if cf != FIFF.FIFFV_COORD_MRI:
raise ValueError("Fiducials are not in MRI space")
fid_loc = _fiducial_coords(mri_fiducials, FIFF.FIFFV_COORD_MRI)
fid_loc = apply_trans(to_cf_t['mri'], fid_loc)
transform = np.eye(4)
transform[:3, :3] = to_cf_t['mri']['trans'][:3, :3] * \
defaults['mri_fid_scale']
# rotate around Z axis 45 deg first
transform = transform @ rotation(0, 0, np.pi / 4)
actors = list()
for color, data in zip(fid_colors, fid_loc):
actor, _ = renderer.quiver3d(
x=data[0], y=data[1], z=data[2],
u=1., v=0., w=0., color=color, mode='oct',
scale=1., opacity=defaults['mri_fid_opacity'],
backface_culling=True, solid_transform=transform)
actors.append(actor)
return actors


def _plot_head_shape_points(renderer, info, to_cf_t):
defaults = DEFAULTS['coreg']
hpi_loc = np.array([
d['r'] for d in (info['dig'] or [])
if (d['kind'] == FIFF.FIFFV_POINT_HPI and
d['coord_frame'] == FIFF.FIFFV_COORD_HEAD)])
hpi_loc = apply_trans(to_cf_t['head'], hpi_loc)
renderer.sphere(center=hpi_loc, color=defaults['hpi_color'],
scale=defaults['hpi_scale'], opacity=0.5,
backface_culling=True)
ext_loc = np.array([
d['r'] for d in (info['dig'] or [])
if (d['kind'] == FIFF.FIFFV_POINT_EXTRA and
d['coord_frame'] == FIFF.FIFFV_COORD_HEAD)])
ext_loc = apply_trans(to_cf_t['head'], ext_loc)
actor, _ = renderer.sphere(center=ext_loc, color=defaults['extra_color'],
scale=defaults['extra_scale'], opacity=0.25,
backface_culling=True)
return actor


def _plot_forward(renderer, fwd, to_cf_t):
from ..forward import Forward
if fwd is not None:
_validate_type(fwd, [Forward])
fwd_rr = fwd['source_rr']
if fwd['source_ori'] == FIFF.FIFFV_MNE_FIXED_ORI:
fwd_nn = fwd['source_nn'].reshape(-1, 1, 3)
else:
fwd_nn = fwd['source_nn'].reshape(-1, 3, 3)
# update coordinate frame
fwd_trans = to_cf_t[_frame_to_str[fwd['coord_frame']]]
fwd_rr = apply_trans(fwd_trans, fwd_rr)
fwd_nn = apply_trans(fwd_trans, fwd_nn, move=False)
red = (1.0, 0.0, 0.0)
green = (0.0, 1.0, 0.0)
blue = (0.0, 0.0, 1.0)
actors = list()
for ori, color in zip(range(fwd_nn.shape[1]), (red, green, blue)):
actor, _ = renderer.quiver3d(
fwd_rr[:, 0],
fwd_rr[:, 1],
fwd_rr[:, 2],
fwd_nn[:, ori, 0],
fwd_nn[:, ori, 1],
fwd_nn[:, ori, 2],
color=color, mode='arrow', scale=1.5e-3)
actors.append(actor)
return actors


def _plot_sensors(renderer, info, to_cf_t, picks, meg, eeg, fnirs,
warn_meg, head_surf, units):
"""Render sensors in a 3D scene."""
defaults = DEFAULTS['coreg']
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/_brain/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2633,7 +2633,7 @@ def add_sensors(self, info, trans, meg=None, eeg='original', fnirs=True,
self._renderer.subplot(ri, ci)
if picks.size > 0:
sensors_actors = _plot_sensors(
info, to_cf_t, self._renderer, picks, meg, eeg,
self._renderer, info, to_cf_t, picks, meg, eeg,
fnirs, head_surf, warn_meg, self._units)
for item, actors in sensors_actors.items():
for actor in actors:
Expand Down
2 changes: 1 addition & 1 deletion tutorials/forward/35_eeg_no_mri.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
for ch_name in raw.ch_names)
raw.rename_channels(new_names)

# Read and set the EEG electrode locations
# Read and set the EEG electrode locations:
montage = mne.channels.make_standard_montage('standard_1005')
raw.set_montage(montage)
raw.set_eeg_reference(projection=True) # needed for inverse modeling
Expand Down

0 comments on commit 3e7914e

Please sign in to comment.