Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MAINT, MRG] Refactor BEM code #9625

Merged
merged 3 commits into from
Jul 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions mne/_freesurfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,44 +57,68 @@ def _import_nibabel(why='use MRI files'):
return nib


def _mri_orientation(img, orientation):
"""Get MRI orientation information from an image.
def _reorient_image(img, axcodes='RAS'):
"""Reorient an image to a given orientation.

Parameters
----------
img : instance of SpatialImage
The MRI image.
axcodes : tuple | str
The axis codes specifying the orientation, e.g. "RAS".
See :func:`nibabel.orientations.aff2axcodes`.

Returns
-------
img_data : ndarray
The reoriented image data.
vox_ras_t : ndarray
The new transform from the new voxels to surface RAS.

Notes
-----
.. versionadded:: 0.24
"""
import nibabel as nib
orig_data = np.array(img.dataobj).astype(np.float32)
# reorient data to RAS
ornt = nib.orientations.axcodes2ornt(
nib.orientations.aff2axcodes(img.affine)).astype(int)
ras_ornt = nib.orientations.axcodes2ornt(axcodes)
ornt_trans = nib.orientations.ornt_transform(ornt, ras_ornt)
img_data = nib.orientations.apply_orientation(orig_data, ornt_trans)
larsoner marked this conversation as resolved.
Show resolved Hide resolved
orig_mgh = nib.MGHImage(orig_data, img.affine)
aff_trans = nib.orientations.inv_ornt_aff(ornt_trans, img.shape)
vox_ras_t = np.dot(orig_mgh.header.get_vox2ras_tkr(), aff_trans)
return img_data, vox_ras_t


def _mri_orientation(orientation):
"""Get MRI orientation information from an image.

Parameters
----------
orientation : str
Orientation that you want. Can be "axial", "saggital", or "coronal".

Returns
-------
xyz : tuple, shape (3,)
The dimension indices for X, Y, and Z.
flips : tuple, shape (3,)
Whether each dimension requires a flip.
order : tuple, shape (3,)
The resulting order of the data if the given ``xyz`` and ``flips``
are used.
axis : int
The dimension of the axis to take slices over when plotting.
x : int
The dimension of the x axis.
y : int
The dimension of the y axis.

Notes
-----
.. versionadded:: 0.21
.. versionchanged:: 0.24
"""
import nibabel as nib
_validate_type(img, nib.spatialimages.SpatialImage)
_check_option('orientation', orientation, ('coronal', 'axial', 'sagittal'))
axcodes = ''.join(nib.orientations.aff2axcodes(img.affine))
flips = {o: (1 if o in axcodes else -1) for o in 'RAS'}
axcodes = axcodes.replace('L', 'R').replace('P', 'A').replace('I', 'S')
order = dict(
coronal=('R', 'S', 'A'),
axial=('R', 'A', 'S'),
sagittal=('A', 'S', 'R'),
)[orientation]
xyz = tuple(axcodes.index(c) for c in order)
flips = tuple(flips[c] for c in order)
return xyz, flips, order
axis = dict(coronal=1, axial=2, sagittal=0)[orientation]
x, y = sorted(set([0, 1, 2]).difference(set([axis])))
return axis, x, y


def _get_mri_info_data(mri, data):
Expand Down
12 changes: 5 additions & 7 deletions mne/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .io import read_raw, read_info
from .io._read_raw import supported as extension_reader_map
from .io.pick import _DATA_CH_TYPES_SPLIT
from ._freesurfer import _mri_orientation
from ._freesurfer import _reorient_image, _mri_orientation
from .utils import (logger, verbose, get_subjects_dir, warn, _ensure_int,
fill_doc, _check_option, _validate_type, _safe_input)
from .viz import (plot_events, plot_alignment, plot_cov, plot_projs_topomap,
Expand Down Expand Up @@ -150,12 +150,9 @@ def _figs_to_mrislices(sl, n_jobs, **kwargs):
parallel, p_fun, _ = parallel_func(_plot_mri_contours, use_jobs)
outs = parallel(p_fun(slices=s, **kwargs)
for s in np.array_split(sl, use_jobs))
# deal with flip_z
flip_z = 1
out = list()
for o in outs:
o, flip_z = o
out.extend(o[::flip_z])
out.extend(o)
return out


Expand Down Expand Up @@ -1875,8 +1872,9 @@ def _render_one_bem_axis(self, mri_fname, surfaces, global_id,
"""Render one axis of bem contours (only PNG)."""
import nibabel as nib
nim = nib.load(mri_fname)
(_, _, z), _, _ = _mri_orientation(nim, orientation)
n_slices = nim.shape[z]
data = _reorient_image(nim)[0]
axis = _mri_orientation(orientation)[0]
n_slices = data.shape[axis]

name = orientation
html = []
Expand Down
55 changes: 21 additions & 34 deletions mne/viz/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
import numpy as np

from ..defaults import DEFAULTS
from ..fixes import _get_img_fdata
from .._freesurfer import _mri_orientation, _read_mri_info, _check_mri
from .._freesurfer import (_reorient_image, _read_mri_info, _check_mri,
_mri_orientation)
from ..rank import compute_rank
from ..surface import read_surface
from ..io.constants import FIFF
from ..io.proj import make_projector
from ..io.pick import (_DATA_CH_TYPES_SPLIT, pick_types, pick_info,
pick_channels)
from ..source_space import read_source_spaces, SourceSpaces, _ensure_src
from ..transforms import invert_transform, apply_trans, _frame_to_str
from ..transforms import apply_trans, _frame_to_str
from ..utils import (logger, verbose, warn, _check_option, get_subjects_dir,
_mask_to_onsets_offsets, _pl, _on_missing, fill_doc)
from ..io.pick import _picks_by_type
Expand Down Expand Up @@ -313,20 +313,14 @@ def _plot_mri_contours(mri_fname, surfaces, src, orientation='coronal',
_check_option('orientation', orientation, ('coronal', 'axial', 'sagittal'))

# Load the T1 data
_, vox_mri_t, _, _, _, nim = _read_mri_info(
_, _, _, _, _, nim = _read_mri_info(
mri_fname, units='mm', return_img=True)
mri_vox_t = invert_transform(vox_mri_t)['trans']
del vox_mri_t

# plot axes (x, y, z) as data axes
(x, y, z), (flip_x, flip_y, flip_z), order = _mri_orientation(
nim, orientation)
transpose = x < y

data = _get_img_fdata(nim)
shift_x = data.shape[x] if flip_x < 0 else 0
shift_y = data.shape[y] if flip_y < 0 else 0
n_slices = data.shape[z]

data, rasvox_mri_t = _reorient_image(nim)
mri_rasvox_t = np.linalg.inv(rasvox_mri_t)
axis, x, y = _mri_orientation(orientation)

n_slices = data.shape[axis]
if slices is None:
slices = np.round(np.linspace(0, n_slices - 1, 14)).astype(int)[1:-1]
slices = np.atleast_1d(slices).copy()
Expand All @@ -337,17 +331,14 @@ def _plot_mri_contours(mri_fname, surfaces, src, orientation='coronal',
raise ValueError('slices must be a sorted 1D array of int with unique '
'elements, at least one element, and no elements '
'greater than %d, got %s' % (n_slices - 1, slices))
if flip_z < 0:
# Proceed in the opposite order to maintain left-to-right / orientation
slices = slices[::-1]

# create of list of surfaces
surfs = list()
for file_name, color in surfaces:
surf = dict()
surf['rr'], surf['tris'] = read_surface(file_name)
# move surface to voxel coordinate system
surf['rr'] = apply_trans(mri_vox_t, surf['rr'])
surf['rr'] = apply_trans(mri_rasvox_t, surf['rr'])
surfs.append((surf, color))

sources = list()
Expand All @@ -360,7 +351,7 @@ def _plot_mri_contours(mri_fname, surfaces, src, orientation='coronal',
f'{_frame_to_str[src[0]["coord_frame"]]}')
for src_ in src:
points = src_['rr'][src_['inuse'].astype(bool)]
sources.append(apply_trans(mri_vox_t, points * 1e3))
sources.append(apply_trans(mri_rasvox_t, points * 1e3))
sources = np.concatenate(sources, axis=0)

if img_output:
Expand All @@ -384,17 +375,15 @@ def _plot_mri_contours(mri_fname, surfaces, src, orientation='coronal',
[[-np.inf], slices[:-1] + np.diff(slices) / 2., [np.inf]]) # float
slicer = [slice(None)] * 3
ori_labels = dict(R='LR', A='PA', S='IS')
xlabels, ylabels = ori_labels[order[0]], ori_labels[order[1]]
xlabels, ylabels = ori_labels['RAS'[x]], ori_labels['RAS'[y]]
path_effects = [patheffects.withStroke(linewidth=4, foreground="k",
alpha=0.75)]
out = list() if img_output else fig
for ai, (ax, sl, lower, upper) in enumerate(zip(
axs, slices, bounds[:-1], bounds[1:])):
# adjust the orientations for good view
slicer[z] = sl
dat = data[tuple(slicer)]
dat = dat.T if transpose else dat
dat = dat[::flip_y, ::flip_x]
slicer[axis] = sl
dat = data[tuple(slicer)].T

# First plot the anatomical data
if img_output:
Expand All @@ -408,16 +397,14 @@ def _plot_mri_contours(mri_fname, surfaces, src, orientation='coronal',
for surf, color in surfs:
with warnings.catch_warnings(record=True): # ignore contour warn
warnings.simplefilter('ignore')
ax.tricontour(flip_x * surf['rr'][:, x] + shift_x,
flip_y * surf['rr'][:, y] + shift_y,
surf['tris'], surf['rr'][:, z],
ax.tricontour(surf['rr'][:, x], surf['rr'][:, y],
surf['tris'], surf['rr'][:, axis],
levels=[sl], colors=color, linewidths=1.0,
zorder=1)

if len(sources):
in_slice = (sources[:, z] >= lower) & (sources[:, z] < upper)
ax.scatter(flip_x * sources[in_slice, x] + shift_x,
flip_y * sources[in_slice, y] + shift_y,
in_slice = (sources[:, axis] >= lower) & (sources[:, axis] < upper)
ax.scatter(sources[in_slice, x], sources[in_slice, y],
marker='.', color='#FF00FF', s=1, zorder=2)
if show_indices:
ax.text(dat.shape[1] // 8 + 0.5, 0.5, str(sl),
Expand Down Expand Up @@ -448,7 +435,7 @@ def _plot_mri_contours(mri_fname, surfaces, src, orientation='coronal',
fig.subplots_adjust(left=0., bottom=0., right=1., top=1., wspace=0.,
hspace=0.)
plt_show(show, fig=fig)
return out, flip_z
return out


def plot_bem(subject=None, subjects_dir=None, orientation='coronal',
Expand Down Expand Up @@ -559,7 +546,7 @@ def plot_bem(subject=None, subjects_dir=None, orientation='coronal',

# Plot the contours
return _plot_mri_contours(mri_fname, surfaces, src, orientation, slices,
show, show_indices, show_orientation)[0]
show, show_indices, show_orientation)
larsoner marked this conversation as resolved.
Show resolved Hide resolved


def _get_bem_plotting_surfaces(bem_path):
Expand Down
7 changes: 4 additions & 3 deletions tutorials/forward/30_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
data_path = sample.data_path()

# the raw file containing the channel location + types
raw_fname = data_path + '/MEG/sample/sample_audvis_raw.fif'
sample_dir = op.join(data_path, 'MEG', 'sample',)
raw_fname = op.join(sample_dir, 'sample_audvis_raw.fif')
# The paths to Freesurfer reconstructions
subjects_dir = data_path + '/subjects'
subjects_dir = op.join(data_path, 'subjects')
subject = 'sample'

# %%
Expand Down Expand Up @@ -75,7 +76,7 @@
# alignment with the following code.

# The transformation file obtained by coregistration
trans = data_path + '/MEG/sample/sample_audvis_raw-trans.fif'
trans = op.join(sample_dir, 'sample_audvis_raw-trans.fif')

info = mne.io.read_info(raw_fname)
# Here we look at the dense head, which isn't used for BEM computations but
Expand Down
Loading