-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add apply_spatial_filter * Revert changes in Mixin and remove support for ChData * change preload to property * Reorder __all__ * only handle "eeg" picks * Update bads handling * Fix data ordering for epochs * Add docs * Create 20_spatial_filter.py * Update latest.rst * WIP tutorial * Apply suggestions from code review Co-authored-by: Mathieu Scheltienne <[email protected]> * Rename tutorial * resolve reviews * Add test for non eeg data * Fix style * Update references.bib * fix doc * Add support for CHdata by adding _get_channel_positions method to it * ADD test for ChData._get_channel_positions method * improve tutorial * Fix style * Fix style * Update 20_spatial_filter.py * Update 20_spatial_filter.py * Update ch_data.py * Update 20_spatial_filter.py * Apply suggestions from code review Co-authored-by: Mathieu Scheltienne <[email protected]> * fix pos * add origin check * minor import and type-hints changes * test against _typing.CHData instead of importing io.ChData * fix intersphinx links for numpy and scipy * improve tutorial * Apply suggestions from code review Co-authored-by: Mathieu Scheltienne <[email protected]> * WIP * Add handle for custom adjacency * Fix style * Fix docstring * Fix reference target not found * Fix doc * Fix style * Add note and logs * fix * Fix style * minor docstring improvement * only use public function in tutorial * test * revert --------- Co-authored-by: Mathieu Scheltienne <[email protected]> Co-authored-by: Mathieu Scheltienne <[email protected]>
- Loading branch information
1 parent
91df686
commit 9e83ccb
Showing
11 changed files
with
547 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,5 +10,6 @@ Preprocessing | |
.. autosummary:: | ||
:toctree: generated/ | ||
|
||
apply_spatial_filter | ||
extract_gfp_peaks | ||
resample |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
from typing import Union | ||
|
||
import numpy as np | ||
from mne import BaseEpochs, pick_info | ||
from mne.bem import _check_origin | ||
from mne.channels import find_ch_adjacency | ||
from mne.channels.interpolation import _make_interpolation_matrix | ||
from mne.io import BaseRaw | ||
from mne.io.pick import _picks_by_type | ||
from mne.parallel import parallel_func | ||
from mne.utils.check import _check_preload | ||
from numpy.typing import NDArray | ||
from scipy.sparse import csr_matrix | ||
|
||
from .._typing import CHData | ||
from ..utils._checks import _check_n_jobs, _check_type, _check_value | ||
from ..utils._docs import fill_doc | ||
from ..utils._logs import logger, verbose | ||
|
||
|
||
def _check_adjacency(adjacency, info, ch_type): | ||
"""Check adjacency matrix.""" | ||
_check_type(adjacency, (csr_matrix, np.ndarray, str), "adjacency") | ||
# auto | ||
if isinstance(adjacency, str): | ||
if adjacency != "auto": | ||
raise ( | ||
ValueError( | ||
f"Adjacency can be either a scipy.sparse.csr_matrix" | ||
f"or 'auto' but got string '{adjacency}' instead." | ||
) | ||
) | ||
adjacency, ch_names = find_ch_adjacency(info, ch_type) | ||
# custom | ||
if isinstance(adjacency, csr_matrix): | ||
adjacency = adjacency.toarray() | ||
if isinstance(adjacency, np.ndarray): | ||
ch_names = info.ch_names | ||
n_channels = len(ch_names) | ||
if adjacency.ndim != 2: | ||
raise ValueError( | ||
"Adjacency must have exactly 2 dimensions but got " | ||
f"{adjacency.ndim} dimensions instead." | ||
) | ||
if (adjacency.shape[0] != n_channels) or ( | ||
adjacency.shape[1] != n_channels | ||
): | ||
raise ValueError( | ||
"Adjacency must be of shape (n_channels, n_channels) " | ||
f"but got {adjacency.shape} instead." | ||
) | ||
if not np.array_equal(adjacency, adjacency.astype(bool)): | ||
raise ValueError( | ||
"Values contained in adjacency can only be 0 or 1." | ||
) | ||
return (adjacency, ch_names) | ||
|
||
|
||
@fill_doc | ||
@verbose | ||
def apply_spatial_filter( | ||
inst: Union[BaseRaw, BaseEpochs, CHData], | ||
ch_type: str = "eeg", | ||
exclude_bads: bool = True, | ||
origin: Union[str, NDArray[float]] = "auto", | ||
adjacency: Union[csr_matrix, str] = "auto", | ||
n_jobs: int = 1, | ||
verbose=None, | ||
): | ||
r"""Apply a spatial filter. | ||
Adapted from \ :footcite:t:`michel2019eeg`. | ||
Apply an instantaneous filter which interpolates channels | ||
with local neighbors while removing outliers. | ||
The current implementation proceeds as follows: | ||
* An interpolation matrix is computed using | ||
mne.channels.interpolation._make_interpolation_matrix. | ||
* An ajdacency matrix is computed using | ||
`mne.channels.find_ch_adjacency`. | ||
* If ``exclude_bads`` is set to ``True``, | ||
bad channels are removed from the ajdacency matrix. | ||
* For each timepoint and each channel, | ||
a reduced adjacency matrix is built by removing neighbors | ||
with lowest and highest value. | ||
* For each timepoint and each channel, | ||
a reduced interpolation matrix is built by extracting neighbor | ||
weights based on the reduced adjacency matrix. | ||
* The reduced interpolation matrices are normalized. | ||
* The channel's timepoints are interpolated | ||
using their reduced interpolation matrix. | ||
Parameters | ||
---------- | ||
inst : Raw | Epochs | ChData | ||
Instance to filter spatially. | ||
ch_type : str | ||
The channel type on which to apply the spatial filter. | ||
Currently only supports ``'eeg'``. | ||
exclude_bads : bool | ||
If set to ``True``, bad channels will be removed | ||
from the adjacency matrix and therefore not used | ||
to interpolate neighbors. In addition, bad channels | ||
will not be filtered. | ||
If set to ``False``, proceed as if all channels were good. | ||
origin : array of shape (3,) | str | ||
Origin of the sphere in the head coordinate frame and in meters. | ||
Can be ``'auto'`` (default), which means a head-digitization-based | ||
origin fit. | ||
adjacency : array or csr_matrix of shape (n_channels, n_channels) | str | ||
An adjacency matrix. Can be created using | ||
`mne.channels.find_ch_adjacency` and edited with | ||
`mne.viz.plot_ch_adjacency`. | ||
If ``'auto'`` (default), the matrix will be automatically created | ||
using `mne.channels.find_ch_adjacency` and other parameters. | ||
%(n_jobs)s | ||
%(verbose)s | ||
Returns | ||
------- | ||
inst : Raw | Epochs| ChData | ||
The instance modified in place. | ||
Notes | ||
----- | ||
This function requires a full copy of the data in memory. | ||
References | ||
---------- | ||
.. footbibliography:: | ||
""" # noqa: E501 | ||
_check_type(inst, (BaseRaw, BaseEpochs, CHData), item_name="inst") | ||
_check_type(ch_type, (str,), item_name="ch_type") | ||
_check_value(ch_type, ("eeg",), item_name="ch_type") | ||
_check_type(exclude_bads, (bool,), item_name="exclude_bads") | ||
n_jobs = _check_n_jobs(n_jobs) | ||
_check_preload(inst, "Apply spatial filter") | ||
if inst.get_montage() is None: | ||
raise ValueError( | ||
"No montage was set on your data, but spatial filter" | ||
"can only work if digitization points for the EEG " | ||
"channels are available. Consider calling inst.set_montage() " | ||
"to apply a montage." | ||
) | ||
# retrieve picks | ||
picks = dict(_picks_by_type(inst.info, exclude=[]))[ch_type] | ||
info = pick_info(inst.info, picks) | ||
# adjacency matrix | ||
adjacency, ch_names = _check_adjacency(adjacency, info, ch_type) | ||
if exclude_bads: | ||
for c, chan in enumerate(ch_names): | ||
if chan in inst.info["bads"]: | ||
adjacency[c, :] = 0 # do not change bads | ||
adjacency[:, c] = 0 # don't use bads to interpolate | ||
# retrieve channel positions | ||
pos = inst._get_channel_positions(picks) | ||
# test spherical fit | ||
origin = _check_origin(origin, info) | ||
distance = np.linalg.norm(pos - origin, axis=-1) | ||
distance = np.mean(distance / np.mean(distance)) | ||
if np.abs(1.0 - distance) > 0.1: | ||
logger.warn( | ||
"Your spherical fit is poor, interpolation results are " | ||
"likely to be inaccurate." | ||
) | ||
pos = pos - origin | ||
interpolate_matrix = _make_interpolation_matrix(pos, pos) | ||
# retrieve data | ||
data = inst.get_data(picks=picks) | ||
if isinstance(inst, BaseEpochs): | ||
data = np.hstack(data) | ||
# apply filter | ||
logger.info(f"Applying spatial filter on {len(picks)} channels.") | ||
if n_jobs == 1: | ||
spatial_filtered_data = [] | ||
for index, adjacency_vector in enumerate(adjacency): | ||
channel_data = _channel_spatial_filter( | ||
index, data, adjacency_vector, interpolate_matrix | ||
) | ||
spatial_filtered_data.append(channel_data) | ||
else: | ||
parallel, p_fun, _ = parallel_func( | ||
_channel_spatial_filter, n_jobs, total=len(adjacency) | ||
) | ||
spatial_filtered_data = parallel( | ||
p_fun(index, data, adjacency_vector, interpolate_matrix) | ||
for index, adjacency_vector in enumerate(adjacency) | ||
) | ||
|
||
data = np.array(spatial_filtered_data) | ||
if isinstance(inst, BaseEpochs): | ||
data = data.reshape( | ||
(len(picks), inst._data.shape[0], inst._data.shape[-1]) | ||
).swapaxes(0, 1) | ||
inst._data[:, picks, :] = data | ||
else: | ||
inst._data[picks] = data | ||
|
||
return inst | ||
|
||
|
||
def _channel_spatial_filter(index, data, adjacency_vector, interpolate_matrix): | ||
neighbors_data = data[adjacency_vector == 1, :] | ||
neighbor_indices = np.argwhere(adjacency_vector == 1) | ||
# too much bads /edge | ||
if len(neighbor_indices) <= 3: | ||
print(index) | ||
return data[index] | ||
# neighbor_matrix shape (n_neighbor, n_samples) | ||
neighbor_matrix = np.array( | ||
[neighbor_indices.flatten().tolist()] * data.shape[-1] | ||
).T | ||
|
||
# Create a mask | ||
max_mask = neighbors_data == np.amax(neighbors_data, keepdims=True, axis=0) | ||
min_mask = neighbors_data == np.amin(neighbors_data, keepdims=True, axis=0) | ||
keep_mask = ~(max_mask | min_mask) | ||
|
||
keep_indices = np.array( | ||
[ | ||
neighbor_matrix[:, i][keep_mask[:, i]] | ||
for i in range(keep_mask.shape[-1]) | ||
] | ||
) | ||
channel_data = data[index] | ||
for i, keep_ind in enumerate(keep_indices): | ||
weights = interpolate_matrix[keep_ind, index] | ||
# normalize weights | ||
weights = weights / np.linalg.norm(weights) | ||
# average | ||
channel_data[i] = np.average(data[keep_ind, i], weights=weights) | ||
return channel_data |
Oops, something went wrong.