Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/vferat/pycrostates
Browse files Browse the repository at this point in the history
  • Loading branch information
Mathieu Scheltienne committed May 10, 2023
2 parents 397af76 + 9e83ccb commit bd412cc
Show file tree
Hide file tree
Showing 10 changed files with 545 additions and 6 deletions.
10 changes: 10 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ @article{MICHEL2018577
year = {2018}
}

@article{michel2019eeg,
author = {Michel, Christoph M and Brunet, Denis},
doi = {10.3389/fneur.2019.00325},
journal = {Frontiers in neurology},
pages = {325},
title = {EEG source imaging: a practical review of the analysis steps},
volume = {10},
year = {2019}
}

@article{Murray2008,
author = {Murray, Micah M. and Brunet, Denis and Michel, Christoph M.},
doi = {10.1007/s10548-008-0054-5},
Expand Down
1 change: 1 addition & 0 deletions docs/source/api/preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ Preprocessing
.. autosummary::
:toctree: generated/

apply_spatial_filter
extract_gfp_peaks
resample
6 changes: 4 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@
# -- intersphinx -------------------------------------------------------------
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/devdocs", None),
"scipy": ("https://scipy.github.io/devdocs", None),
"numpy": ("https://numpy.org/doc/stable", None),
"scipy": ("https://docs.scipy.org/doc/scipy", None),
"matplotlib": ("https://matplotlib.org", None),
"mne": ("https://mne.tools/stable/", None),
"joblib": ("https://joblib.readthedocs.io/en/latest", None),
Expand Down Expand Up @@ -171,6 +171,8 @@
"Axes": "matplotlib.axes.Axes",
"colormap": ":doc:`colormap <matplotlib:tutorials/colors/colormaps>`",
"Figure": "matplotlib.figure.Figure",
# Scipy
"csr_matrix": "scipy.sparse.csr_matrix",
}
numpydoc_xref_ignore = {
"instance",
Expand Down
4 changes: 2 additions & 2 deletions pycrostates/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from numpy.random import Generator, RandomState
from numpy.typing import NDArray

from .._typing import Picks, RANDomState
from .._typing import CHData, Picks, RANDomState
from ..utils import _corr_vectors
from ..utils._checks import _check_n_jobs, _check_random_state, _check_type
from ..utils._docs import copy_doc, fill_doc
Expand Down Expand Up @@ -137,7 +137,7 @@ def _check_fit(self):
@fill_doc
def fit(
self,
inst: Union[BaseRaw, BaseEpochs],
inst: Union[BaseRaw, BaseEpochs, CHData],
picks: Picks = "eeg",
tmin: Optional[Union[int, float]] = None,
tmax: Optional[Union[int, float]] = None,
Expand Down
34 changes: 34 additions & 0 deletions pycrostates/io/ch_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,30 @@ def pick(self, picks, exclude="bads"):
self._data = data
return self

def _get_channel_positions(self, picks=None):
"""Get channel locations from info.
Parameters
----------
picks : str | list | slice | None
None selects the good data channels.
Returns
-------
pos : array of shape (n_channels, 3)
Channel X/Y/Z locations.
"""
picks = _picks_to_idx(self.info, picks)
chs = self.info["chs"]
pos = np.array([chs[k]["loc"][:3] for k in picks])
n_zero = np.sum(np.sum(np.abs(pos), axis=1) == 0)
if n_zero > 1: # XXX some systems have origin (0, 0, 0)
raise ValueError(
"Could not extract channel positions for "
f"{n_zero} channels."
)
return pos

# --------------------------------------------------------------------
@property
def info(self) -> CHInfo:
Expand All @@ -146,3 +170,13 @@ def info(self) -> CHInfo:
:type: ChInfo
"""
return self._info

@property
def ch_names(self):
"""Channel names."""
return self.info["ch_names"]

@property
def preload(self):
"""Preload required by some MNE functions."""
return True
12 changes: 11 additions & 1 deletion pycrostates/io/tests/test_ch_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
import pytest
from mne import create_info
from mne import create_info, pick_types

from pycrostates.io import ChData, ChInfo

Expand Down Expand Up @@ -114,3 +114,13 @@ def test_ChData_invalid_arguments():
ChData(data.reshape(6, 5, 400), ch_info)
with pytest.raises(ValueError, match="'data' and 'info' do not have"):
ChData(data, create_info(2, 400, "eeg"))


def test_ChData_get_channel_positions():
"""Test get for sensor positions."""
ch_data = ChData(data, ch_info_types.copy())
ch_data.set_montage("standard_1020")
picks = pick_types(ch_data.info, meg=False, eeg=True)
pos = np.array([ch["loc"][:3] for ch in ch_data.info["chs"]])[picks]
ch_data_pos = ch_data._get_channel_positions(picks=picks)
assert np.all(ch_data_pos == pos)
3 changes: 2 additions & 1 deletion pycrostates/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@

from .extract_gfp_peaks import extract_gfp_peaks
from .resample import resample
from .spatial_filter import apply_spatial_filter

__all__ = ("extract_gfp_peaks", "resample")
__all__ = ("apply_spatial_filter", "extract_gfp_peaks", "resample")
232 changes: 232 additions & 0 deletions pycrostates/preprocessing/spatial_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
from typing import Union

import numpy as np
from mne import BaseEpochs, pick_info
from mne.bem import _check_origin
from mne.channels import find_ch_adjacency
from mne.channels.interpolation import _make_interpolation_matrix
from mne.io import BaseRaw
from mne.io.pick import _picks_by_type
from mne.parallel import parallel_func
from mne.utils.check import _check_preload
from numpy.typing import NDArray
from scipy.sparse import csr_matrix

from .._typing import CHData
from ..utils._checks import _check_n_jobs, _check_type, _check_value
from ..utils._docs import fill_doc
from ..utils._logs import logger, verbose


def _check_adjacency(adjacency, info, ch_type):
"""Check adjacency matrix."""
_check_type(adjacency, (csr_matrix, np.ndarray, str), "adjacency")
# auto
if isinstance(adjacency, str):
if adjacency != "auto":
raise (
ValueError(
f"Adjacency can be either a scipy.sparse.csr_matrix"
f"or 'auto' but got string '{adjacency}' instead."
)
)
adjacency, ch_names = find_ch_adjacency(info, ch_type)
# custom
if isinstance(adjacency, csr_matrix):
adjacency = adjacency.toarray()
if isinstance(adjacency, np.ndarray):
ch_names = info.ch_names
n_channels = len(ch_names)
if adjacency.ndim != 2:
raise ValueError(
"Adjacency must have exactly 2 dimensions but got "
f"{adjacency.ndim} dimensions instead."
)
if (adjacency.shape[0] != n_channels) or (
adjacency.shape[1] != n_channels
):
raise ValueError(
"Adjacency must be of shape (n_channels, n_channels) "
f"but got {adjacency.shape} instead."
)
if not np.array_equal(adjacency, adjacency.astype(bool)):
raise ValueError(
"Values contained in adjacency can only be 0 or 1."
)
return (adjacency, ch_names)


@fill_doc
@verbose
def apply_spatial_filter(
inst: Union[BaseRaw, BaseEpochs, CHData],
ch_type: str = "eeg",
exclude_bads: bool = True,
origin: Union[str, NDArray[float]] = "auto",
adjacency: Union[csr_matrix, str] = "auto",
n_jobs: int = 1,
verbose=None,
):
r"""Apply a spatial filter.
Adapted from \ :footcite:t:`michel2019eeg`.
Apply an instantaneous filter which interpolates channels
with local neighbors while removing outliers.
The current implementation proceeds as follows:
* An interpolation matrix is computed using
mne.channels.interpolation._make_interpolation_matrix.
* An ajdacency matrix is computed using
`mne.channels.find_ch_adjacency`.
* If ``exclude_bads`` is set to ``True``,
bad channels are removed from the ajdacency matrix.
* For each timepoint and each channel,
a reduced adjacency matrix is built by removing neighbors
with lowest and highest value.
* For each timepoint and each channel,
a reduced interpolation matrix is built by extracting neighbor
weights based on the reduced adjacency matrix.
* The reduced interpolation matrices are normalized.
* The channel's timepoints are interpolated
using their reduced interpolation matrix.
Parameters
----------
inst : Raw | Epochs | ChData
Instance to filter spatially.
ch_type : str
The channel type on which to apply the spatial filter.
Currently only supports ``'eeg'``.
exclude_bads : bool
If set to ``True``, bad channels will be removed
from the adjacency matrix and therefore not used
to interpolate neighbors. In addition, bad channels
will not be filtered.
If set to ``False``, proceed as if all channels were good.
origin : array of shape (3,) | str
Origin of the sphere in the head coordinate frame and in meters.
Can be ``'auto'`` (default), which means a head-digitization-based
origin fit.
adjacency : array or csr_matrix of shape (n_channels, n_channels) | str
An adjacency matrix. Can be created using
`mne.channels.find_ch_adjacency` and edited with
`mne.viz.plot_ch_adjacency`.
If ``'auto'`` (default), the matrix will be automatically created
using `mne.channels.find_ch_adjacency` and other parameters.
%(n_jobs)s
%(verbose)s
Returns
-------
inst : Raw | Epochs| ChData
The instance modified in place.
Notes
-----
This function requires a full copy of the data in memory.
References
----------
.. footbibliography::
""" # noqa: E501
_check_type(inst, (BaseRaw, BaseEpochs, CHData), item_name="inst")
_check_type(ch_type, (str,), item_name="ch_type")
_check_value(ch_type, ("eeg",), item_name="ch_type")
_check_type(exclude_bads, (bool,), item_name="exclude_bads")
n_jobs = _check_n_jobs(n_jobs)
_check_preload(inst, "Apply spatial filter")
if inst.get_montage() is None:
raise ValueError(
"No montage was set on your data, but spatial filter"
"can only work if digitization points for the EEG "
"channels are available. Consider calling inst.set_montage() "
"to apply a montage."
)
# retrieve picks
picks = dict(_picks_by_type(inst.info, exclude=[]))[ch_type]
info = pick_info(inst.info, picks)
# adjacency matrix
adjacency, ch_names = _check_adjacency(adjacency, info, ch_type)
if exclude_bads:
for c, chan in enumerate(ch_names):
if chan in inst.info["bads"]:
adjacency[c, :] = 0 # do not change bads
adjacency[:, c] = 0 # don't use bads to interpolate
# retrieve channel positions
pos = inst._get_channel_positions(picks)
# test spherical fit
origin = _check_origin(origin, info)
distance = np.linalg.norm(pos - origin, axis=-1)
distance = np.mean(distance / np.mean(distance))
if np.abs(1.0 - distance) > 0.1:
logger.warn(
"Your spherical fit is poor, interpolation results are "
"likely to be inaccurate."
)
pos = pos - origin
interpolate_matrix = _make_interpolation_matrix(pos, pos)
# retrieve data
data = inst.get_data(picks=picks)
if isinstance(inst, BaseEpochs):
data = np.hstack(data)
# apply filter
logger.info(f"Applying spatial filter on {len(picks)} channels.")
if n_jobs == 1:
spatial_filtered_data = []
for index, adjacency_vector in enumerate(adjacency):
channel_data = _channel_spatial_filter(
index, data, adjacency_vector, interpolate_matrix
)
spatial_filtered_data.append(channel_data)
else:
parallel, p_fun, _ = parallel_func(
_channel_spatial_filter, n_jobs, total=len(adjacency)
)
spatial_filtered_data = parallel(
p_fun(index, data, adjacency_vector, interpolate_matrix)
for index, adjacency_vector in enumerate(adjacency)
)

data = np.array(spatial_filtered_data)
if isinstance(inst, BaseEpochs):
data = data.reshape(
(len(picks), inst._data.shape[0], inst._data.shape[-1])
).swapaxes(0, 1)
inst._data[:, picks, :] = data
else:
inst._data[picks] = data

return inst


def _channel_spatial_filter(index, data, adjacency_vector, interpolate_matrix):
neighbors_data = data[adjacency_vector == 1, :]
neighbor_indices = np.argwhere(adjacency_vector == 1)
# too much bads /edge
if len(neighbor_indices) <= 3:
print(index)
return data[index]
# neighbor_matrix shape (n_neighbor, n_samples)
neighbor_matrix = np.array(
[neighbor_indices.flatten().tolist()] * data.shape[-1]
).T

# Create a mask
max_mask = neighbors_data == np.amax(neighbors_data, keepdims=True, axis=0)
min_mask = neighbors_data == np.amin(neighbors_data, keepdims=True, axis=0)
keep_mask = ~(max_mask | min_mask)

keep_indices = np.array(
[
neighbor_matrix[:, i][keep_mask[:, i]]
for i in range(keep_mask.shape[-1])
]
)
channel_data = data[index]
for i, keep_ind in enumerate(keep_indices):
weights = interpolate_matrix[keep_ind, index]
# normalize weights
weights = weights / np.linalg.norm(weights)
# average
channel_data[i] = np.average(data[keep_ind, i], weights=weights)
return channel_data
Loading

0 comments on commit bd412cc

Please sign in to comment.