diff --git a/examples/decoding/decoding_csp_eeg.py b/examples/decoding/decoding_csp_eeg.py index 893e7969c7a..a0e84c00c15 100644 --- a/examples/decoding/decoding_csp_eeg.py +++ b/examples/decoding/decoding_csp_eeg.py @@ -50,6 +50,7 @@ montage = make_standard_montage("standard_1005") raw.set_montage(montage) raw.annotations.rename(dict(T1="hands", T2="feet")) +raw.set_eeg_reference(projection=True) # Apply band-pass filter raw.filter(7.0, 30.0, fir_design="firwin", skip_by_annotation="edge") diff --git a/mne/conftest.py b/mne/conftest.py index fd7d1946843..3bb9a8eb7ce 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -960,6 +960,8 @@ def pytest_sessionfinish(session, exitstatus): # get the number to print files = defaultdict(lambda: 0.0) for item in session.items: + if _phase_report_key not in item.stash: + continue report = item.stash[_phase_report_key] dur = sum(x.duration for x in report.values()) parts = Path(item.nodeid.split(":")[0]).parts diff --git a/mne/cov.py b/mne/cov.py index 7b9a4b24252..f09510a5c32 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -1250,6 +1250,7 @@ def _compute_covariance_auto( rank, scalings, info, + verbose=_verbose_safe_false(), ) with _scaled_array(data.T, picks_list, scalings): C = np.dot(data.T, data) diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index ac3983e4617..ce91f54d32f 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -12,7 +12,8 @@ import numpy as np from scipy.linalg import eigh -from ..cov import _regularized_covariance +from .._fiff.meas_info import create_info +from ..cov import _regularized_covariance, _smart_eigh from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT from ..evoked import EvokedArray from ..fixes import pinv @@ -185,6 +186,9 @@ def fit(self, X, y): f"{n_classes} classes; use component_order='mutual_info' instead." ) + # Convert rank to one that will run + _validate_type(self.rank, (dict, None), "rank") + covs, sample_weights = self._compute_covariance_matrices(X, y) eigen_vectors, eigen_values = self._decompose_covs(covs, sample_weights) ix = self._order_components( @@ -519,6 +523,13 @@ def _compute_covariance_matrices(self, X, y): elif self.cov_est == "epoch": cov_estimator = self._epoch_cov + # TODO: We should allow the user to pass this, then we won't need to convert + self._info = create_info(n_channels, 1000.0, "mag") + if self.rank is None: + self._rank = self.rank + else: + self._rank = {"mag": sum(self.rank.values())} + covs = [] sample_weights = [] for this_class in self._classes: @@ -539,7 +550,11 @@ def _concat_cov(self, x_class): x_class = np.transpose(x_class, [1, 0, 2]) x_class = x_class.reshape(n_channels, -1) cov = _regularized_covariance( - x_class, reg=self.reg, method_params=self.cov_method_params, rank=self.rank + x_class, + reg=self.reg, + method_params=self.cov_method_params, + rank=self._rank, + info=self._info, ) weight = x_class.shape[0] @@ -552,7 +567,8 @@ def _epoch_cov(self, x_class): this_X, reg=self.reg, method_params=self.cov_method_params, - rank=self.rank, + rank=self._rank, + info=self._info, ) for this_X in x_class ) @@ -563,6 +579,17 @@ def _epoch_cov(self, x_class): def _decompose_covs(self, covs, sample_weights): n_classes = len(covs) + n_channels = covs[0].shape[0] + _, sub_vec, mask = _smart_eigh( + covs.mean(0), + self._info, + self._rank, + proj_subspace=True, + do_compute_rank=self._rank is None, + ) + sub_vec = sub_vec[mask] + covs = np.array([sub_vec @ cov @ sub_vec.T for cov in covs], float) + assert covs[0].shape == (mask.sum(),) * 2 if n_classes == 2: eigen_values, eigen_vectors = eigh(covs[0], covs.sum(0)) else: @@ -573,6 +600,9 @@ def _decompose_covs(self, covs, sample_weights): eigen_vectors.T, covs, sample_weights ) eigen_values = None + # project back + eigen_vectors = sub_vec.T @ eigen_vectors + assert eigen_vectors.shape == (n_channels, mask.sum()) return eigen_vectors, eigen_values def _compute_mutual_info(self, covs, sample_weights, eigen_vectors): diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py index e632a02e2a7..039b234b5ec 100644 --- a/mne/decoding/tests/test_csp.py +++ b/mne/decoding/tests/test_csp.py @@ -245,40 +245,51 @@ def test_csp(): assert np.abs(corr) > 0.95 -def test_regularized_csp(): +# Even the "reg is None and rank is None" case should pass now thanks to the +# do_compute_rank +@pytest.mark.parametrize("ch_type", ("mag", "eeg")) +@pytest.mark.parametrize("rank", (None, "correct")) +@pytest.mark.parametrize("reg", [None, 0.05, "ledoit_wolf", "oas"]) +def test_regularized_csp(ch_type, rank, reg): """Test Common Spatial Patterns algorithm using regularized covariance.""" pytest.importorskip("sklearn") - raw = io.read_raw_fif(raw_fname) + raw = io.read_raw_fif(raw_fname).pick(ch_type, exclude="bads") + raw.pick(raw.ch_names[:30]).load_data() + if ch_type == "eeg": + raw.set_eeg_reference(projection=True) + n_eig = len(raw.ch_names) - len(raw.info["projs"]) + if ch_type == "eeg": + assert n_eig == 29 + else: + assert n_eig == 27 + if rank == "correct": + rank = {ch_type: n_eig} + else: + assert rank is None, rank + raw.info.normalize_proj() events = read_events(event_name) - picks = pick_types( - raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" - ) - picks = picks[1:13:3] - epochs = Epochs( - raw, events, event_id, tmin, tmax, picks=picks, baseline=(None, 0), preload=True - ) + epochs = Epochs(raw, events, event_id, tmin, tmax, baseline=(None, 0), preload=True) epochs_data = epochs.get_data(copy=False) n_channels = epochs_data.shape[1] - + assert n_channels == 30 n_components = 3 - reg_cov = [None, 0.05, "ledoit_wolf", "oas"] - for reg in reg_cov: - csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=None) - csp.fit(epochs_data, epochs.events[:, -1]) - y = epochs.events[:, -1] - X = csp.fit_transform(epochs_data, y) - assert csp.filters_.shape == (n_channels, n_channels) - assert csp.patterns_.shape == (n_channels, n_channels) - assert_array_almost_equal(csp.fit(epochs_data, y).transform(epochs_data), X) - - # test init exception - pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) - pytest.raises(ValueError, csp.fit, epochs, y) - pytest.raises(ValueError, csp.transform, epochs) - - csp.n_components = n_components - sources = csp.transform(epochs_data) - assert sources.shape[1] == n_components + + csp = CSP(n_components=n_components, reg=reg, norm_trace=False, rank=rank) + csp.fit(epochs_data, epochs.events[:, -1]) + y = epochs.events[:, -1] + X = csp.fit_transform(epochs_data, y) + assert csp.filters_.shape == (n_eig, n_channels) + assert csp.patterns_.shape == (n_eig, n_channels) + assert_array_almost_equal(csp.fit(epochs_data, y).transform(epochs_data), X) + + # test init exception + pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) + pytest.raises(ValueError, csp.fit, epochs, y) + pytest.raises(ValueError, csp.transform, epochs) + + csp.n_components = n_components + sources = csp.transform(epochs_data) + assert sources.shape[1] == n_components def test_csp_pipeline():