BUG: Fix bug with CSP rank
larsoner committed Mar 1, 2024
1 parent 668b508 commit d9a7d8e
Showing 5 changed files with 76 additions and 31 deletions.
1 change: 1 addition & 0 deletions examples/decoding/decoding_csp_eeg.py
@@ -50,6 +50,7 @@
montage = make_standard_montage("standard_1005")
raw.set_montage(montage)
raw.annotations.rename(dict(T1="hands", T2="feet"))
raw.set_eeg_reference(projection=True)

# Apply band-pass filter
raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge")
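
The added average-reference projection is what makes this example exercise the fix: projecting out the mean across EEG channels leaves the data rank-deficient by one, which the CSP rank handling changed below has to cope with. A standalone NumPy sketch (my own illustration, not part of the commit) of that rank drop:

import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal((16, 1000))              # 16 simulated "EEG" channels
avg_ref = data - data.mean(axis=0, keepdims=True)   # average reference across channels
print(np.linalg.matrix_rank(data))                  # 16
print(np.linalg.matrix_rank(avg_ref))               # 15: one dimension projected out
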
2 changes: 2 additions & 0 deletions mne/conftest.py
@@ -960,6 +960,8 @@ def pytest_sessionfinish(session, exitstatus):
# get the number to print
files = defaultdict(lambda: 0.0)
for item in session.items:
if _phase_report_key not in item.stash:
continue
report = item.stash[_phase_report_key]
dur = sum(x.duration for x in report.values())
parts = Path(item.nodeid.split(":")[0]).parts
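
For context, the guard added above follows the usual pytest stash pattern: per-phase test reports are stored on each item from a report hook, but an item can end up without the key (for example if it never reaches that hook), so lookups must be guarded. A minimal sketch of that pattern (assuming pytest >= 7; the key name mirrors the conftest, the hook body is my own illustration):

import pytest

_phase_report_key = pytest.StashKey[dict]()

@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
    outcome = yield
    report = outcome.get_result()
    # Store the setup/call/teardown reports on the item; an item that never
    # reaches this hook has no entry under the key, hence the guard in
    # pytest_sessionfinish above.
    item.stash.setdefault(_phase_report_key, {})[report.when] = report
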
1 change: 1 addition & 0 deletions mne/cov.py
@@ -1250,6 +1250,7 @@ def _compute_covariance_auto(
rank,
scalings,
info,
verbose=_verbose_safe_false(),
)
with _scaled_array(data.T, picks_list, scalings):
C = np.dot(data.T, data)
36 changes: 33 additions & 3 deletions mne/decoding/csp.py
@@ -12,7 +12,8 @@
import numpy as np
from scipy.linalg import eigh

from ..cov import _regularized_covariance
from .._fiff.meas_info import create_info
from ..cov import _regularized_covariance, _smart_eigh
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
from ..evoked import EvokedArray
from ..fixes import pinv
@@ -185,6 +186,9 @@ def fit(self, X, y):
f"{n_classes} classes; use component_order='mutual_info' instead."
)

# Convert rank to one that will run
_validate_type(self.rank, (dict, None), "rank")

covs, sample_weights = self._compute_covariance_matrices(X, y)
eigen_vectors, eigen_values = self._decompose_covs(covs, sample_weights)
ix = self._order_components(
@@ -519,6 +523,13 @@ def _compute_covariance_matrices(self, X, y):
elif self.cov_est == "epoch":
cov_estimator = self._epoch_cov

# TODO: We should allow the user to pass this, then we won't need to convert
self._info = create_info(n_channels, 1000.0, "mag")
if self.rank is None:
self._rank = self.rank
else:
self._rank = {"mag": sum(self.rank.values())}

covs = []
sample_weights = []
for this_class in self._classes:
@@ -539,7 +550,11 @@ def _concat_cov(self, x_class):
x_class = np.transpose(x_class, [1, 0, 2])
x_class = x_class.reshape(n_channels, -1)
cov = _regularized_covariance(
x_class, reg=self.reg, method_params=self.cov_method_params, rank=self.rank
x_class,
reg=self.reg,
method_params=self.cov_method_params,
rank=self._rank,
info=self._info,
)
weight = x_class.shape[0]

@@ -552,7 +567,8 @@ def _epoch_cov(self, x_class):
this_X,
reg=self.reg,
method_params=self.cov_method_params,
rank=self.rank,
rank=self._rank,
info=self._info,
)
for this_X in x_class
)
@@ -563,6 +579,17 @@ def _epoch_cov(self, x_class):

def _decompose_covs(self, covs, sample_weights):
n_classes = len(covs)
n_channels = covs[0].shape[0]
_, sub_vec, mask = _smart_eigh(
covs.mean(0),
self._info,
self._rank,
proj_subspace=True,
do_compute_rank=self._rank is None,
)
sub_vec = sub_vec[mask]
covs = np.array([sub_vec @ cov @ sub_vec.T for cov in covs], float)
assert covs[0].shape == (mask.sum(),) * 2
if n_classes == 2:
eigen_values, eigen_vectors = eigh(covs[0], covs.sum(0))
else:
@@ -573,6 +600,9 @@ def _decompose_covs(self, covs, sample_weights):
eigen_vectors.T, covs, sample_weights
)
eigen_values = None
# project back
eigen_vectors = sub_vec.T @ eigen_vectors
assert eigen_vectors.shape == (n_channels, mask.sum())
return eigen_vectors, eigen_values

def _compute_mutual_info(self, covs, sample_weights, eigen_vectors):
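
Taken together, the csp.py changes restrict the CSP eigenproblem to the estimated (or user-supplied) rank: _smart_eigh yields a projector onto the signal subspace of the mean covariance, the per-class covariances are projected into that subspace, the generalized eigenproblem is solved there, and the eigenvectors are mapped back to channel space. A self-contained sketch of that idea in plain NumPy/SciPy (my own illustration, not the MNE implementation):

import numpy as np
from scipy.linalg import eigh


def rank_restricted_csp_filters(cov_a, cov_b, rank):
    """Solve the two-class CSP eigenproblem inside a rank-limited subspace."""
    cov_mean = (cov_a + cov_b) / 2.0
    # eigenvectors of the mean covariance; keep the `rank` strongest directions
    _, eigvec = np.linalg.eigh(cov_mean)
    sub = eigvec[:, -rank:].T                     # (rank, n_channels) projector
    a_sub = sub @ cov_a @ sub.T
    b_sub = sub @ cov_b @ sub.T
    # the generalized eigenproblem is well-posed in the reduced space
    _, filters_sub = eigh(a_sub, a_sub + b_sub)
    return (sub.T @ filters_sub).T                # back to channel space: (rank, n_channels)


rng = np.random.default_rng(0)
X = rng.standard_normal((8, 500))
X -= X.mean(axis=0)                               # rank-deficient, like average-referenced EEG
cov_a = X[:, :250] @ X[:, :250].T
cov_b = X[:, 250:] @ X[:, 250:].T
print(rank_restricted_csp_filters(cov_a, cov_b, rank=7).shape)  # (7, 8)

Without the subspace restriction, eigh(covs[0], covs.sum(0)) on rank-deficient covariances is ill-posed, which is presumably the failure mode this commit addresses.
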
67 changes: 39 additions & 28 deletions mne/decoding/tests/test_csp.py
@@ -245,40 +245,51 @@ def test_csp():
assert np.abs(corr) > 0.95


def test_regularized_csp():
# Even the "reg is None and rank is None" case should pass now thanks to the
# do_compute_rank
@pytest.mark.parametrize("ch_type", ("mag", "eeg"))
@pytest.mark.parametrize("rank", (None, "correct"))
@pytest.mark.parametrize("reg", [None, 0.05, "ledoit_wolf", "oas"])
def test_regularized_csp(ch_type, rank, reg):
"""Test Common Spatial Patterns algorithm using regularized covariance."""
pytest.importorskip("sklearn")
raw = io.read_raw_fif(raw_fname)
raw = io.read_raw_fif(raw_fname).pick(ch_type, exclude="bads")
raw.pick(raw.ch_names[:30]).load_data()
if ch_type == "eeg":
raw.set_eeg_reference(projection=True)
n_eig = len(raw.ch_names) - len(raw.info["projs"])
if ch_type == "eeg":
assert n_eig == 29
else:
assert n_eig == 27
if rank == "correct":
rank = {ch_type: n_eig}
else:
assert rank is None, rank
raw.info.normalize_proj()
events = read_events(event_name)
picks = pick_types(
raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads"
)
picks = picks[1:13:3]
epochs = Epochs(
raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True
)
epochs = Epochs(raw, events, event_id, tmin, tmax, baseline=(None, 0), preload=True)
epochs_data = epochs.get_data(copy=False)
n_channels = epochs_data.shape[1]

assert n_channels == 30
n_components = 3
reg_cov = [None, 0.05, "ledoit_wolf", "oas"]
for reg in reg_cov:
csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=None)
csp.fit(epochs_data, epochs.events[:, -1])
y = epochs.events[:, -1]
X = csp.fit_transform(epochs_data, y)
assert csp.filters_.shape == (n_channels, n_channels)
assert csp.patterns_.shape == (n_channels, n_channels)
assert_array_almost_equal(csp.fit(epochs_data, y).transform(epochs_data), X)

# test init exception
pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events))
pytest.raises(ValueError, csp.fit, epochs, y)
pytest.raises(ValueError, csp.transform, epochs)

csp.n_components = n_components
sources = csp.transform(epochs_data)
assert sources.shape[1] == n_components

csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=rank)
csp.fit(epochs_data, epochs.events[:, -1])
y = epochs.events[:, -1]
X = csp.fit_transform(epochs_data, y)
assert csp.filters_.shape == (n_eig, n_channels)
assert csp.patterns_.shape == (n_eig, n_channels)
assert_array_almost_equal(csp.fit(epochs_data, y).transform(epochs_data), X)

# test init exception
pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events))
pytest.raises(ValueError, csp.fit, epochs, y)
pytest.raises(ValueError, csp.transform, epochs)

csp.n_components = n_components
sources = csp.transform(epochs_data)
assert sources.shape[1] == n_components


def test_csp_pipeline():
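
In user-facing terms, the reworked test pins down the contract that rank may be either None (estimated internally) or a per-channel-type dict, and that filters_ and patterns_ now have shape (n_eig, n_channels) rather than (n_channels, n_channels). A hypothetical usage sketch along those lines (synthetic data and parameter values are my own, not taken from the test):

import numpy as np
from mne.decoding import CSP

rng = np.random.default_rng(0)
epochs_data = rng.standard_normal((40, 10, 200))          # (n_epochs, n_channels, n_times)
epochs_data -= epochs_data.mean(axis=1, keepdims=True)    # mimic average-referenced data (rank 9)
y = np.repeat([0, 1], 20)

csp = CSP(n_components=3, rank={"eeg": 9})                # rank one less than n_channels
X = csp.fit_transform(epochs_data, y)
print(csp.filters_.shape)   # expected (9, 10) per the new (n_eig, n_channels) assertion
print(X.shape)              # (40, 3): average power of the selected components
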
