Skip to content

Commit

Permalink
Merge branch 'main' into efficient
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner authored Jan 10, 2024
2 parents be3c83e + d2c806c commit d11363c
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 41 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12336.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow :meth:`mne.io.Raw.interpolate_bads` and :meth:`mne.Epochs.interpolate_bads` to work on ``ecog`` and ``seeg`` data; for ``seeg`` data a spline is fit to neighboring electrode contacts on the same shaft, by `Alex Rockhill`_
38 changes: 27 additions & 11 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,9 +870,12 @@ def interpolate_bads(
.. versionadded:: 0.9.0
"""
from .interpolation import (
_interpolate_bads_ecog,
_interpolate_bads_eeg,
_interpolate_bads_meeg,
_interpolate_bads_nan,
_interpolate_bads_nirs,
_interpolate_bads_seeg,
)

_check_preload(self, "interpolation")
Expand All @@ -894,35 +897,48 @@ def interpolate_bads(
"eeg": ("spline", "MNE", "nan"),
"meg": ("MNE", "nan"),
"fnirs": ("nearest", "nan"),
"ecog": ("spline", "nan"),
"seeg": ("spline", "nan"),
}
for key in method:
_check_option("method[key]", key, ("meg", "eeg", "fnirs"))
_check_option("method[key]", key, tuple(valids))
_check_option(f"method['{key}']", method[key], valids[key])
logger.info("Setting channel interpolation method to %s.", method)
idx = _picks_to_idx(self.info, list(method), exclude=(), allow_empty=True)
if idx.size == 0 or len(pick_info(self.info, idx)["bads"]) == 0:
warn("No bad channels to interpolate. Doing nothing...")
return self
for ch_type in method.copy():
idx = _picks_to_idx(self.info, ch_type, exclude=(), allow_empty=True)
if len(pick_info(self.info, idx)["bads"]) == 0:
method.pop(ch_type)
logger.info("Interpolating bad channels.")
origin = _check_origin(origin, self.info)
needs_origin = [key != "seeg" and val != "nan" for key, val in method.items()]
if any(needs_origin):
origin = _check_origin(origin, self.info)
for ch_type, interp in method.items():
if interp == "nan":
_interpolate_bads_nan(self, ch_type, exclude=exclude)
if method.get("eeg", "") == "spline":
_interpolate_bads_eeg(self, origin=origin, exclude=exclude)
eeg_mne = False
elif "eeg" not in method:
eeg_mne = False
else:
eeg_mne = True
if "meg" in method or eeg_mne:
meg_mne = method.get("meg", "") == "MNE"
eeg_mne = method.get("eeg", "") == "MNE"
if meg_mne or eeg_mne:
_interpolate_bads_meeg(
self,
mode=mode,
origin=origin,
meg=meg_mne,
eeg=eeg_mne,
origin=origin,
exclude=exclude,
method=method,
)
if "fnirs" in method:
_interpolate_bads_nirs(self, exclude=exclude, method=method["fnirs"])
if method.get("fnirs", "") == "nearest":
_interpolate_bads_nirs(self, exclude=exclude)
if method.get("ecog", "") == "spline":
_interpolate_bads_ecog(self, origin=origin, exclude=exclude)
if method.get("seeg", "") == "spline":
_interpolate_bads_seeg(self, exclude=exclude)

if reset_bads is True:
if "nan" in method.values():
Expand Down
142 changes: 116 additions & 26 deletions mne/channels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@

import numpy as np
from numpy.polynomial.legendre import legval
from scipy.interpolate import RectBivariateSpline
from scipy.linalg import pinv
from scipy.spatial.distance import pdist, squareform

from .._fiff.meas_info import _simplify_info
from .._fiff.pick import pick_channels, pick_info, pick_types
from ..surface import _normalize_vectors
from ..utils import _check_option, _validate_type, logger, verbose, warn
from ..utils import _validate_type, logger, verbose, warn


def _calc_h(cosang, stiffness=4, n_legendre_terms=50):
Expand Down Expand Up @@ -132,13 +133,13 @@ def _do_interp_dots(inst, interpolation, goods_idx, bads_idx):


@verbose
def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None):
def _interpolate_bads_eeg(inst, origin, exclude=None, ecog=False, verbose=None):
if exclude is None:
exclude = list()
bads_idx = np.zeros(len(inst.ch_names), dtype=bool)
goods_idx = np.zeros(len(inst.ch_names), dtype=bool)

picks = pick_types(inst.info, meg=False, eeg=True, exclude=exclude)
picks = pick_types(inst.info, meg=False, eeg=not ecog, ecog=ecog, exclude=exclude)
inst.info._check_consistency()
bads_idx[picks] = [inst.ch_names[ch] in inst.info["bads"] for ch in picks]

Expand Down Expand Up @@ -172,6 +173,11 @@ def _interpolate_bads_eeg(inst, origin, exclude=None, verbose=None):
_do_interp_dots(inst, interpolation, goods_idx, bads_idx)


@verbose
def _interpolate_bads_ecog(inst, origin, exclude=None, verbose=None):
_interpolate_bads_eeg(inst, origin, exclude=exclude, ecog=True, verbose=verbose)


def _interpolate_bads_meg(
inst, mode="accurate", origin=(0.0, 0.0, 0.04), verbose=None, ref_meg=False
):
Expand All @@ -180,6 +186,26 @@ def _interpolate_bads_meg(
)


@verbose
def _interpolate_bads_nan(
inst,
ch_type,
ref_meg=False,
exclude=(),
*,
verbose=None,
):
info = _simplify_info(inst.info)
picks_type = pick_types(info, ref_meg=ref_meg, exclude=exclude, **{ch_type: True})
use_ch_names = [inst.info["ch_names"][p] for p in picks_type]
bads_type = [ch for ch in inst.info["bads"] if ch in use_ch_names]
if len(bads_type) == 0 or len(picks_type) == 0:
return
# select the bad channels to be interpolated
picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])
inst._data[..., picks_bad, :] = np.nan


@verbose
def _interpolate_bads_meeg(
inst,
Expand Down Expand Up @@ -213,10 +239,6 @@ def _interpolate_bads_meeg(
# select the bad channels to be interpolated
picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])

if method[ch_type] == "nan":
inst._data[picks_bad] = np.nan
continue

# do MNE based interpolation
if ch_type == "eeg":
picks_to = picks_type
Expand All @@ -232,7 +254,7 @@ def _interpolate_bads_meeg(


@verbose
def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None):
def _interpolate_bads_nirs(inst, exclude=(), verbose=None):
from mne.preprocessing.nirs import _validate_nirs_info

if len(pick_types(inst.info, fnirs=True, exclude=())) == 0:
Expand All @@ -251,25 +273,93 @@ def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None):
chs = [inst.info["chs"][i] for i in picks_nirs]
locs3d = np.array([ch["loc"][:3] for ch in chs])

_check_option("fnirs_method", method, ["nearest", "nan"])

if method == "nearest":
dist = pdist(locs3d)
dist = squareform(dist)

for bad in picks_bad:
dists_to_bad = dist[bad]
# Ignore distances to self
dists_to_bad[dists_to_bad == 0] = np.inf
# Ignore distances to other bad channels
dists_to_bad[bads_mask] = np.inf
# Find closest remaining channels for same frequency
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
inst._data[bad] = inst._data[closest_idx]
else:
assert method == "nan"
inst._data[picks_bad] = np.nan
dist = pdist(locs3d)
dist = squareform(dist)

for bad in picks_bad:
dists_to_bad = dist[bad]
# Ignore distances to self
dists_to_bad[dists_to_bad == 0] = np.inf
# Ignore distances to other bad channels
dists_to_bad[bads_mask] = np.inf
# Find closest remaining channels for same frequency
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
inst._data[bad] = inst._data[closest_idx]

# TODO: this seems like a bug because it does not respect reset_bads
inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude]

return inst


def _find_seeg_electrode_shaft(pos, tol=2e-3):
# 1) find nearest neighbor to define the electrode shaft line
# 2) find all contacts on the same line

dist = squareform(pdist(pos))
np.fill_diagonal(dist, np.inf)

shafts = list()
for i, n1 in enumerate(pos):
if any([i in shaft for shaft in shafts]):
continue
n2 = pos[np.argmin(dist[i])] # 1
# https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
shaft_dists = np.linalg.norm(
np.cross((pos - n1), (pos - n2)), axis=1
) / np.linalg.norm(n2 - n1)
shafts.append(np.where(shaft_dists < tol)[0]) # 2
return shafts


@verbose
def _interpolate_bads_seeg(inst, exclude=None, tol=2e-3, verbose=None):
if exclude is None:
exclude = list()
picks = pick_types(inst.info, meg=False, seeg=True, exclude=exclude)
inst.info._check_consistency()
bads_idx = np.isin(np.array(inst.ch_names)[picks], inst.info["bads"])

if len(picks) == 0 or bads_idx.sum() == 0:
return

pos = inst._get_channel_positions(picks)

# Make sure only sEEG are used
bads_idx_pos = bads_idx[picks]

shafts = _find_seeg_electrode_shaft(pos, tol=tol)

# interpolate the bad contacts
picks_bad = list(np.where(bads_idx_pos)[0])
for shaft in shafts:
bads_shaft = np.array([idx for idx in picks_bad if idx in shaft])
if bads_shaft.size == 0:
continue
goods_shaft = shaft[np.isin(shaft, bads_shaft, invert=True)]
if goods_shaft.size < 2:
raise RuntimeError(
f"{goods_shaft.size} good contact(s) found in a line "
f" with {np.array(inst.ch_names)[bads_shaft]}, "
"at least 2 are required for interpolation. "
"Dropping this channel/these channels is recommended."
)
logger.debug(
f"Interpolating {np.array(inst.ch_names)[bads_shaft]} using "
f"data from {np.array(inst.ch_names)[goods_shaft]}"
)
bads_shaft_idx = np.where(np.isin(shaft, bads_shaft))[0]
goods_shaft_idx = np.where(~np.isin(shaft, bads_shaft))[0]
n1, n2 = pos[shaft][:2]
ts = np.array(
[
-np.dot(n1 - n0, n2 - n1) / np.linalg.norm(n2 - n1) ** 2
for n0 in pos[shaft]
]
)
if np.any(np.diff(ts) < 0):
ts *= -1
y = np.arange(inst._data.shape[-1])
inst._data[bads_shaft] = RectBivariateSpline(
x=ts[goods_shaft_idx], y=y, z=inst._data[goods_shaft]
)(x=ts[bads_shaft_idx], y=y) # 3
50 changes: 50 additions & 0 deletions mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mne import Epochs, pick_channels, pick_types, read_events
from mne._fiff.constants import FIFF
from mne._fiff.proj import _has_eeg_average_ref_proj
from mne.channels import make_dig_montage
from mne.channels.interpolation import _make_interpolation_matrix
from mne.datasets import testing
from mne.io import RawArray, read_raw_ctf, read_raw_fif, read_raw_nirx
Expand Down Expand Up @@ -329,6 +330,55 @@ def test_interpolation_nirs():
assert raw_haemo.info["bads"] == []


@testing.requires_testing_data
def test_interpolation_ecog():
"""Test interpolation for ECoG."""
raw, epochs_eeg = _load_data("eeg")
bads = ["EEG 012"]
bads_mask = np.isin(epochs_eeg.ch_names, bads)

epochs_ecog = epochs_eeg.set_channel_types(
{ch: "ecog" for ch in epochs_eeg.ch_names}
)
epochs_ecog.info["bads"] = bads

# check that interpolation changes the data in raw
raw_ecog = RawArray(data=epochs_ecog._data[0], info=epochs_ecog.info)
raw_before = raw_ecog.copy()
raw_after = raw_ecog.interpolate_bads(method=dict(ecog="spline"))
assert not np.all(raw_before._data[bads_mask] == raw_after._data[bads_mask])
assert_array_equal(raw_before._data[~bads_mask], raw_after._data[~bads_mask])


@testing.requires_testing_data
def test_interpolation_seeg():
"""Test interpolation for sEEG."""
raw, epochs_eeg = _load_data("eeg")
bads = ["EEG 012"]
bads_mask = np.isin(epochs_eeg.ch_names, bads)
epochs_seeg = epochs_eeg.set_channel_types(
{ch: "seeg" for ch in epochs_eeg.ch_names}
)
epochs_seeg.info["bads"] = bads

# check that interpolation changes the data in raw
raw_seeg = RawArray(data=epochs_seeg._data[0], info=epochs_seeg.info)
raw_before = raw_seeg.copy()
with pytest.raises(RuntimeError, match="1 good contact"):
raw_seeg.interpolate_bads(method=dict(seeg="spline"))
montage = raw_seeg.get_montage()
pos = montage.get_positions()
ch_pos = pos.pop("ch_pos")
n0 = ch_pos[epochs_seeg.ch_names[0]]
n1 = ch_pos[epochs_seeg.ch_names[1]]
for i, ch in enumerate(epochs_seeg.ch_names[2:]):
ch_pos[ch] = n0 + (n1 - n0) * (i + 2)
raw_seeg.set_montage(make_dig_montage(ch_pos, **pos))
raw_after = raw_seeg.interpolate_bads(method=dict(seeg="spline"))
assert not np.all(raw_before._data[bads_mask] == raw_after._data[bads_mask])
assert_array_equal(raw_before._data[~bads_mask], raw_after._data[~bads_mask])


def test_nan_interpolation(raw):
"""Test 'nan' method for interpolating bads."""
ch_to_interp = [raw.ch_names[1]] # don't use channel 0 (type is IAS not MEG)
Expand Down
4 changes: 3 additions & 1 deletion mne/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,9 @@
combine_xyz="fro",
allow_fixed_depth=True,
),
interpolation_method=dict(eeg="spline", meg="MNE", fnirs="nearest"),
interpolation_method=dict(
eeg="spline", meg="MNE", fnirs="nearest", ecog="spline", seeg="spline"
),
volume_options=dict(
alpha=None,
resolution=1.0,
Expand Down
6 changes: 3 additions & 3 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2977,11 +2977,11 @@ def _ensure_list(x):
*last_cols,
]

data = np.empty((len(events_df), len(columns)))
data = np.empty((len(events_df), len(columns)), float)
metadata = pd.DataFrame(data=data, columns=columns, index=events_df.index)

# Event names
metadata.iloc[:, 0] = ""
metadata["event_name"] = ""

# Event times
start_idx = 1
Expand All @@ -2990,7 +2990,7 @@ def _ensure_list(x):

# keep_first and keep_last names
start_idx = stop_idx
metadata.iloc[:, start_idx:] = None
metadata[columns[start_idx:]] = ""

# We're all set, let's iterate over all events and fill in in the
# respective cells in the metadata. We will subset this to include only
Expand Down

0 comments on commit d11363c

Please sign in to comment.