-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/vferat/pycrostates
- Loading branch information
Showing
10 changed files
with
545 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,5 +10,6 @@ Preprocessing | |
.. autosummary:: | ||
:toctree: generated/ | ||
|
||
apply_spatial_filter | ||
extract_gfp_peaks | ||
resample |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
from typing import Union | ||
|
||
import numpy as np | ||
from mne import BaseEpochs, pick_info | ||
from mne.bem import _check_origin | ||
from mne.channels import find_ch_adjacency | ||
from mne.channels.interpolation import _make_interpolation_matrix | ||
from mne.io import BaseRaw | ||
from mne.io.pick import _picks_by_type | ||
from mne.parallel import parallel_func | ||
from mne.utils.check import _check_preload | ||
from numpy.typing import NDArray | ||
from scipy.sparse import csr_matrix | ||
|
||
from .._typing import CHData | ||
from ..utils._checks import _check_n_jobs, _check_type, _check_value | ||
from ..utils._docs import fill_doc | ||
from ..utils._logs import logger, verbose | ||
|
||
|
||
def _check_adjacency(adjacency, info, ch_type): | ||
"""Check adjacency matrix.""" | ||
_check_type(adjacency, (csr_matrix, np.ndarray, str), "adjacency") | ||
# auto | ||
if isinstance(adjacency, str): | ||
if adjacency != "auto": | ||
raise ( | ||
ValueError( | ||
f"Adjacency can be either a scipy.sparse.csr_matrix" | ||
f"or 'auto' but got string '{adjacency}' instead." | ||
) | ||
) | ||
adjacency, ch_names = find_ch_adjacency(info, ch_type) | ||
# custom | ||
if isinstance(adjacency, csr_matrix): | ||
adjacency = adjacency.toarray() | ||
if isinstance(adjacency, np.ndarray): | ||
ch_names = info.ch_names | ||
n_channels = len(ch_names) | ||
if adjacency.ndim != 2: | ||
raise ValueError( | ||
"Adjacency must have exactly 2 dimensions but got " | ||
f"{adjacency.ndim} dimensions instead." | ||
) | ||
if (adjacency.shape[0] != n_channels) or ( | ||
adjacency.shape[1] != n_channels | ||
): | ||
raise ValueError( | ||
"Adjacency must be of shape (n_channels, n_channels) " | ||
f"but got {adjacency.shape} instead." | ||
) | ||
if not np.array_equal(adjacency, adjacency.astype(bool)): | ||
raise ValueError( | ||
"Values contained in adjacency can only be 0 or 1." | ||
) | ||
return (adjacency, ch_names) | ||
|
||
|
||
@fill_doc | ||
@verbose | ||
def apply_spatial_filter( | ||
inst: Union[BaseRaw, BaseEpochs, CHData], | ||
ch_type: str = "eeg", | ||
exclude_bads: bool = True, | ||
origin: Union[str, NDArray[float]] = "auto", | ||
adjacency: Union[csr_matrix, str] = "auto", | ||
n_jobs: int = 1, | ||
verbose=None, | ||
): | ||
r"""Apply a spatial filter. | ||
Adapted from \ :footcite:t:`michel2019eeg`. | ||
Apply an instantaneous filter which interpolates channels | ||
with local neighbors while removing outliers. | ||
The current implementation proceeds as follows: | ||
* An interpolation matrix is computed using | ||
mne.channels.interpolation._make_interpolation_matrix. | ||
* An ajdacency matrix is computed using | ||
`mne.channels.find_ch_adjacency`. | ||
* If ``exclude_bads`` is set to ``True``, | ||
bad channels are removed from the ajdacency matrix. | ||
* For each timepoint and each channel, | ||
a reduced adjacency matrix is built by removing neighbors | ||
with lowest and highest value. | ||
* For each timepoint and each channel, | ||
a reduced interpolation matrix is built by extracting neighbor | ||
weights based on the reduced adjacency matrix. | ||
* The reduced interpolation matrices are normalized. | ||
* The channel's timepoints are interpolated | ||
using their reduced interpolation matrix. | ||
Parameters | ||
---------- | ||
inst : Raw | Epochs | ChData | ||
Instance to filter spatially. | ||
ch_type : str | ||
The channel type on which to apply the spatial filter. | ||
Currently only supports ``'eeg'``. | ||
exclude_bads : bool | ||
If set to ``True``, bad channels will be removed | ||
from the adjacency matrix and therefore not used | ||
to interpolate neighbors. In addition, bad channels | ||
will not be filtered. | ||
If set to ``False``, proceed as if all channels were good. | ||
origin : array of shape (3,) | str | ||
Origin of the sphere in the head coordinate frame and in meters. | ||
Can be ``'auto'`` (default), which means a head-digitization-based | ||
origin fit. | ||
adjacency : array or csr_matrix of shape (n_channels, n_channels) | str | ||
An adjacency matrix. Can be created using | ||
`mne.channels.find_ch_adjacency` and edited with | ||
`mne.viz.plot_ch_adjacency`. | ||
If ``'auto'`` (default), the matrix will be automatically created | ||
using `mne.channels.find_ch_adjacency` and other parameters. | ||
%(n_jobs)s | ||
%(verbose)s | ||
Returns | ||
------- | ||
inst : Raw | Epochs| ChData | ||
The instance modified in place. | ||
Notes | ||
----- | ||
This function requires a full copy of the data in memory. | ||
References | ||
---------- | ||
.. footbibliography:: | ||
""" # noqa: E501 | ||
_check_type(inst, (BaseRaw, BaseEpochs, CHData), item_name="inst") | ||
_check_type(ch_type, (str,), item_name="ch_type") | ||
_check_value(ch_type, ("eeg",), item_name="ch_type") | ||
_check_type(exclude_bads, (bool,), item_name="exclude_bads") | ||
n_jobs = _check_n_jobs(n_jobs) | ||
_check_preload(inst, "Apply spatial filter") | ||
if inst.get_montage() is None: | ||
raise ValueError( | ||
"No montage was set on your data, but spatial filter" | ||
"can only work if digitization points for the EEG " | ||
"channels are available. Consider calling inst.set_montage() " | ||
"to apply a montage." | ||
) | ||
# retrieve picks | ||
picks = dict(_picks_by_type(inst.info, exclude=[]))[ch_type] | ||
info = pick_info(inst.info, picks) | ||
# adjacency matrix | ||
adjacency, ch_names = _check_adjacency(adjacency, info, ch_type) | ||
if exclude_bads: | ||
for c, chan in enumerate(ch_names): | ||
if chan in inst.info["bads"]: | ||
adjacency[c, :] = 0 # do not change bads | ||
adjacency[:, c] = 0 # don't use bads to interpolate | ||
# retrieve channel positions | ||
pos = inst._get_channel_positions(picks) | ||
# test spherical fit | ||
origin = _check_origin(origin, info) | ||
distance = np.linalg.norm(pos - origin, axis=-1) | ||
distance = np.mean(distance / np.mean(distance)) | ||
if np.abs(1.0 - distance) > 0.1: | ||
logger.warn( | ||
"Your spherical fit is poor, interpolation results are " | ||
"likely to be inaccurate." | ||
) | ||
pos = pos - origin | ||
interpolate_matrix = _make_interpolation_matrix(pos, pos) | ||
# retrieve data | ||
data = inst.get_data(picks=picks) | ||
if isinstance(inst, BaseEpochs): | ||
data = np.hstack(data) | ||
# apply filter | ||
logger.info(f"Applying spatial filter on {len(picks)} channels.") | ||
if n_jobs == 1: | ||
spatial_filtered_data = [] | ||
for index, adjacency_vector in enumerate(adjacency): | ||
channel_data = _channel_spatial_filter( | ||
index, data, adjacency_vector, interpolate_matrix | ||
) | ||
spatial_filtered_data.append(channel_data) | ||
else: | ||
parallel, p_fun, _ = parallel_func( | ||
_channel_spatial_filter, n_jobs, total=len(adjacency) | ||
) | ||
spatial_filtered_data = parallel( | ||
p_fun(index, data, adjacency_vector, interpolate_matrix) | ||
for index, adjacency_vector in enumerate(adjacency) | ||
) | ||
|
||
data = np.array(spatial_filtered_data) | ||
if isinstance(inst, BaseEpochs): | ||
data = data.reshape( | ||
(len(picks), inst._data.shape[0], inst._data.shape[-1]) | ||
).swapaxes(0, 1) | ||
inst._data[:, picks, :] = data | ||
else: | ||
inst._data[picks] = data | ||
|
||
return inst | ||
|
||
|
||
def _channel_spatial_filter(index, data, adjacency_vector, interpolate_matrix): | ||
neighbors_data = data[adjacency_vector == 1, :] | ||
neighbor_indices = np.argwhere(adjacency_vector == 1) | ||
# too much bads /edge | ||
if len(neighbor_indices) <= 3: | ||
print(index) | ||
return data[index] | ||
# neighbor_matrix shape (n_neighbor, n_samples) | ||
neighbor_matrix = np.array( | ||
[neighbor_indices.flatten().tolist()] * data.shape[-1] | ||
).T | ||
|
||
# Create a mask | ||
max_mask = neighbors_data == np.amax(neighbors_data, keepdims=True, axis=0) | ||
min_mask = neighbors_data == np.amin(neighbors_data, keepdims=True, axis=0) | ||
keep_mask = ~(max_mask | min_mask) | ||
|
||
keep_indices = np.array( | ||
[ | ||
neighbor_matrix[:, i][keep_mask[:, i]] | ||
for i in range(keep_mask.shape[-1]) | ||
] | ||
) | ||
channel_data = data[index] | ||
for i, keep_ind in enumerate(keep_indices): | ||
weights = interpolate_matrix[keep_ind, index] | ||
# normalize weights | ||
weights = weights / np.linalg.norm(weights) | ||
# average | ||
channel_data[i] = np.average(data[keep_ind, i], weights=weights) | ||
return channel_data |
Oops, something went wrong.