diff --git a/coffeine/__init__.py b/coffeine/__init__.py
index 8142eb5..bf4dca9 100644
--- a/coffeine/__init__.py
+++ b/coffeine/__init__.py
@@ -29,3 +29,5 @@
 from .power_features import compute_features, get_frequency_bands, compute_coffeine, make_coffeine_data_frame  # noqa
 from .spatial_filters import ProjIdentitySpace, ProjCommonSpace, ProjLWSpace, ProjRandomSpace, ProjSPoCSpace  # noqa
+
+from .transfer_learning import ReCenter, ReScale  # noqa
\ No newline at end of file
diff --git a/coffeine/pipelines.py b/coffeine/pipelines.py
index 8533af3..21aa292 100644
--- a/coffeine/pipelines.py
+++ b/coffeine/pipelines.py
@@ -16,12 +16,20 @@
     ProjRandomSpace,
     ProjSPoCSpace)
 
+from coffeine.transfer_learning import (
+    ReCenter,
+    ReScale
+)
+
+import sklearn
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.compose import make_column_transformer
 from sklearn.pipeline import make_pipeline, Pipeline
 from sklearn.preprocessing import StandardScaler
 from sklearn.linear_model import RidgeCV, LogisticRegression
 
+sklearn.set_config(enable_metadata_routing=True)
+
 
 class GaussianKernel(BaseEstimator, TransformerMixin):
     """Gaussian (squared exponential) Kernel.
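[review note] `sklearn.set_config(enable_metadata_routing=True)` runs at import time and switches metadata routing on for the whole process. If that global side effect is undesirable, a scoped alternative is sketched below — assuming scikit-learn >= 1.3, where this flag exists; `X_df`, `y` and `domains` are placeholders, not part of this diff:

import sklearn
from coffeine import make_filter_bank_regressor

# Keep routing enabled only where needed; the context must also cover any
# later predict/score call that routes 'domains' through the pipeline.
with sklearn.config_context(enable_metadata_routing=True):
    model = make_filter_bank_regressor(
        names=['alpha'], method='riemann', alignment=['re-center'])
    model.fit(X_df, y, domains=domains)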
@@ -142,6 +150,7 @@ def transform(self,
 def make_filter_bank_transformer(
         names: list[str],
         method: str = 'riemann',
+        alignment: Union[list[str], None] = None,
         projection_params: Union[dict, None] = None,
         vectorization_params: Union[dict, None] = None,
         kernel: Union[str, Pipeline, None] = None,
@@ -184,6 +193,9 @@ def make_filter_bank_transformer(
         to ``'riemann'``. Can be ``'riemann'``, ``'lw_riemann'``, ``'diag'``,
         ``'log_diag'``, ``'random'``, ``'naive'``, ``'spoc'``,
         ``'riemann_wasserstein'``.
+    alignment : list of str | None
+        Alignment steps to include in the pipeline. Can be ``'re-center'``,
+        ``'re-scale'``.
     projection_params : dict | None
         The parameters for the projection step.
     vectorization_params : dict | None
@@ -249,13 +261,28 @@ def make_filter_bank_transformer(
     if vectorization_params is not None:
         vectorization_params_.update(**vectorization_params)
 
-    def _get_projector_vectorizer(projection, vectorization, kernel=None):
+    def _get_projector_vectorizer(projection, vectorization,
+                                  recenter, rescale,
+                                  kernel=None):
         out = list()
         for name in names:
-            steps = [
-                projection(**projection_params_),
-                vectorization(**vectorization_params_)
-            ]
+            steps = [projection(**projection_params_)]
+
+            if recenter is not None:
+                steps.append(
+                    recenter().set_fit_request(
+                        domains=True
+                    ).set_transform_request(domains=True)
+                )
+
+            if rescale is not None:
+                steps.append(
+                    rescale().set_fit_request(
+                        domains=True
+                    ).set_transform_request(domains=True)
+                )
+            steps.append(vectorization(**vectorization_params_))
+
             if kernel is not None:
                 kernel_name, kernel_estimator = kernel
                 steps.append(kernel_estimator())
@@ -282,6 +309,16 @@ def _get_projector_vectorizer(projection, vectorization, kernel=None):
     elif method == 'riemann_wasserstein':
         steps = (ProjIdentitySpace, RiemannSnp)
 
+    # add alignment options
+    alignment_steps = {
+        'recenter': None,
+        'rescale': None
+    }
+    if isinstance(alignment, list) and 're-center' in alignment:
+        alignment_steps['recenter'] = ReCenter
+    if isinstance(alignment, list) and 're-scale' in alignment:
+        alignment_steps['rescale'] = ReScale
+
     # add Kernel options
     if (isinstance(kernel, Pipeline) and
             not isinstance(kernel, (BaseEstimator, TransformerMixin))):
@@ -295,7 +332,8 @@ def _get_projector_vectorizer(projection, vectorization, kernel=None):
         combine_kernels = 'sum'
 
     filter_bank_transformer = make_column_transformer(
-        *_get_projector_vectorizer(*steps, kernel=kernel),
+        *_get_projector_vectorizer(*steps, **alignment_steps,
+                                   kernel=kernel),
         remainder='passthrough'
     )
@@ -314,6 +352,7 @@ def _get_projector_vectorizer(projection, vectorization, kernel=None):
 def make_filter_bank_regressor(
         names: list[str],
         method: str = 'riemann',
+        alignment: Union[list[str], None] = None,
         projection_params: Union[dict, None] = None,
         vectorization_params: Union[dict, None] = None,
         categorical_interaction: Union[bool, None] = None,
@@ -356,6 +395,9 @@ def make_filter_bank_regressor(
         to ``'riemann'``. Can be ``'riemann'``, ``'lw_riemann'``, ``'diag'``,
         ``'log_diag'``, ``'random'``, ``'naive'``, ``'spoc'``,
         ``'riemann_wasserstein'``.
+    alignment : list of str | None
+        Alignment steps to include in the pipeline. Can be ``'re-center'``,
+        ``'re-scale'``.
     projection_params : dict | None
         The parameters for the projection step.
     vectorization_params : dict | None
@@ -379,7 +421,8 @@ def make_filter_bank_regressor(
     https://doi.org/10.1016/j.neuroimage.2020.116893
     """
     filter_bank_transformer = make_filter_bank_transformer(
-        names=names, method=method, projection_params=projection_params,
+        names=names, method=method, alignment=alignment,
+        projection_params=projection_params,
         vectorization_params=vectorization_params,
         categorical_interaction=categorical_interaction
     )
@@ -404,6 +447,7 @@ def make_filter_bank_regressor(
 def make_filter_bank_classifier(
         names: list[str],
         method: str = 'riemann',
+        alignment: Union[list[str], None] = None,
         projection_params: Union[dict, None] = None,
         vectorization_params: Union[dict, None] = None,
         categorical_interaction: Union[bool, None] = None,
@@ -446,6 +490,9 @@ def make_filter_bank_classifier(
         to ``'riemann'``. Can be ``'riemann'``, ``'lw_riemann'``, ``'diag'``,
         ``'log_diag'``, ``'random'``, ``'naive'``, ``'spoc'``,
         ``'riemann_wasserstein'``.
+    alignment : list of str | None
+        Alignment steps to include in the pipeline. Can be ``'re-center'``,
+        ``'re-scale'``.
     projection_params : dict | None
         The parameters for the projection step.
     vectorization_params : dict | None
@@ -470,7 +517,8 @@ def make_filter_bank_classifier(
     """
     filter_bank_transformer = make_filter_bank_transformer(
-        names=names, method=method, projection_params=projection_params,
+        names=names, method=method, alignment=alignment,
+        projection_params=projection_params,
         vectorization_params=vectorization_params,
         categorical_interaction=categorical_interaction
     )
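[review note] End to end, the new `alignment` option works through scikit-learn metadata routing: `ReCenter`/`ReScale` request the `domains` metadata, so a single `domains` argument passed to `fit`/`predict`/`score` is routed through the column transformer into every band's sub-pipeline. A minimal sketch on synthetic SPD matrices — the data, names and sizes below are illustrative, not part of this diff, and it assumes a scikit-learn recent enough (>= 1.4) to route metadata through pipelines:

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from coffeine import make_filter_bank_classifier

rng = np.random.RandomState(42)

def random_spd(n, d):
    # A @ A.T plus a diagonal ridge is symmetric positive definite
    A = rng.randn(n, d, d)
    return A @ A.transpose(0, 2, 1) + d * np.eye(d)

n_train, n_test, d = 60, 20, 4
X_df = pd.DataFrame({'alpha': list(random_spd(n_train, d)),
                     'beta': list(random_spd(n_train, d))})
y = rng.randint(0, 2, n_train)

model = make_filter_bank_classifier(
    names=['alpha', 'beta'], method='riemann',
    alignment=['re-center', 're-scale'],
    estimator=LogisticRegression())
# 'domains' is routed to the ReCenter/ReScale steps of each band column
model.fit(X_df, y, domains=['source'] * n_train)

X_df_new = pd.DataFrame({'alpha': list(random_spd(n_test, d)),
                         'beta': list(random_spd(n_test, d))})
y_pred = model.predict(X_df_new, domains=['target_domain'] * n_test)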
diff --git a/coffeine/tests/test_transfer_learning.py b/coffeine/tests/test_transfer_learning.py
new file mode 100644
index 0000000..82e1bb2
--- /dev/null
+++ b/coffeine/tests/test_transfer_learning.py
@@ -0,0 +1,63 @@
+import pytest
+import numpy as np
+from pyriemann.datasets import make_classification_transfer
+from pyriemann.transfer import decode_domains
+from pyriemann.utils.mean import mean_covariance
+from pyriemann.utils.distance import distance
+
+from coffeine.transfer_learning import ReCenter, ReScale
+
+
+def test_recenter():
+    n_matrices = 200
+    domain_sep = 5
+    X, y_enc = make_classification_transfer(n_matrices, domain_sep=domain_sep,
+                                            random_state=42)
+    _, y, domains = decode_domains(X, y_enc)
+    train_index = [
+        i for i in range(len(domains)) if domains[i] != 'target_domain'
+    ]
+    test_index = [
+        i for i in range(len(domains)) if domains[i] == 'target_domain'
+    ]
+    X_train, y_train = X[train_index], y[train_index]
+    X_test = X[test_index]
+    rct = ReCenter(metric='riemann')
+    X_train_rct = rct.fit_transform(X_train, y_train,
+                                    domains=domains[train_index])
+    X_test_rct = rct.transform(X_test, domains=domains[test_index])
+    # Test if mean is Identity
+    M_train = mean_covariance(X_train_rct, metric='riemann')
+    assert M_train == pytest.approx(np.eye(2))
+    M_test = mean_covariance(X_test_rct, metric='riemann')
+    assert M_test == pytest.approx(np.eye(2))
+
+
+def test_rescale():
+    n_matrices = 100
+    stretch = 3
+    X, y_enc = make_classification_transfer(n_matrices, stretch=stretch,
+                                            random_state=42)
+    _, y, domains = decode_domains(X, y_enc)
+    train_index = [
+        i for i in range(len(domains)) if domains[i] != 'target_domain'
+    ]
+    test_index = [
+        i for i in range(len(domains)) if domains[i] == 'target_domain'
+    ]
+    X_train, y_train = X[train_index], y[train_index]
+    X_test = X[test_index]
+    rsc = ReScale(metric='riemann')
+    X_train_rsc = rsc.fit_transform(X_train, y_train,
+                                    domains=domains[train_index])
+    X_test_rsc = rsc.transform(X_test, domains=domains[test_index])
+    # Test if dispersion == 1
+    M_train = mean_covariance(X_train_rsc, metric='riemann')
+    disp_train = np.mean(
+        distance(X_train_rsc, M_train, metric='riemann')**2
+    )
+    assert np.isclose(disp_train, 1.0)
+    M_test = mean_covariance(X_test_rsc, metric='riemann')
+    disp_test = np.mean(
+        distance(X_test_rsc, M_test, metric='riemann')**2
+    )
+    assert np.isclose(disp_test, 1.0)
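[review note] The tests rely on pyRiemann's extended-label convention: domain and class are packed into a single string per matrix, and `encode_domains`/`decode_domains` convert between the two representations. A small round-trip sketch (values illustrative):

import numpy as np
from pyriemann.transfer import encode_domains, decode_domains

X = np.stack([np.eye(2)] * 4)
y = np.array([0, 1, 0, 1])
domains = ['source_domain', 'source_domain',
           'target_domain', 'target_domain']

# extended labels look like 'source_domain/0'
X_enc, y_enc = encode_domains(X, y, domains)
X_dec, y_dec, domains_dec = decode_domains(X_enc, y_enc)
print(y_enc)
print(y_dec, domains_dec)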
diff --git a/coffeine/transfer_learning.py b/coffeine/transfer_learning.py
new file mode 100644
index 0000000..2ffada2
--- /dev/null
+++ b/coffeine/transfer_learning.py
@@ -0,0 +1,285 @@
+import numpy as np
+import pyriemann
+from pyriemann.transfer._tools import decode_domains
+from pyriemann.utils.mean import mean_riemann
+from pyriemann.utils.distance import distance
+from pyriemann.transfer import TLCenter, TLStretch, encode_domains
+from sklearn.base import BaseEstimator, TransformerMixin
+
+
+def _check_data(X):
+    # make a proper 3D array of covariances
+    out = None
+    if X.ndim == 3:
+        out = X
+    elif X.values.dtype == 'object':
+        # first remove unnecessary dimensions,
+        # then stack to 3D data
+        values = X.values
+        if values.shape[1] == 1:
+            values = values[:, 0]
+        out = np.stack(values)
+        if out.ndim == 2:  # deal with single sample
+            assert out.shape[0] == out.shape[1]
+            out = out[np.newaxis, :, :]
+    return out
+
+
+def _check_domains(domains, n_sample, target=False):
+    out = None
+    if domains is None:
+        if not target:
+            out = ['source_domain'] * n_sample
+        else:
+            out = ['target_domain'] * n_sample
+    else:
+        if target and 'target_domain' not in domains:
+            raise ValueError(
+                "The target domains should include 'target_domain'"
+            )
+        else:
+            out = domains
+    return out
+
+
+class TLStretch_patch(TLStretch):
+    """Patched version of TLStretch.
+
+    Used in ReScale when the installed pyRiemann version is lower
+    than 0.6."""
+
+    def __init__(self, target_domain, final_dispersion=1.0,
+                 centered_data=False, metric='riemann'):
+        super().__init__(target_domain, final_dispersion,
+                         centered_data, metric)
+
+    def fit(self, X, y_enc):
+        """Fit TLStretch_patch.
+
+        Calculate the dispersion around the mean for each domain.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_matrices, n_channels, n_channels)
+            Set of SPD matrices.
+        y_enc : ndarray, shape (n_matrices,)
+            Extended labels for each matrix.
+
+        Returns
+        -------
+        self : TLStretch_patch instance
+            The TLStretch_patch instance.
+        """
+
+        _, _, domains = decode_domains(X, y_enc)
+        n_dim = X[0].shape[1]
+        self._means = {}
+        self.dispersions_ = {}
+        for d in np.unique(domains):
+            if self.centered_data:
+                self._means[d] = np.eye(n_dim)
+            else:
+                self._means[d] = mean_riemann(X[domains == d])
+            disp_domain = distance(
+                X[domains == d],
+                self._means[d],
+                metric=self.metric,
+                squared=True,
+            ).mean()
+            self.dispersions_[d] = disp_domain
+
+        return self
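[review note] The quantity the patched `fit` stores per domain is the mean squared Riemannian distance of the matrices to their geometric mean; `TLStretch.transform` later stretches the matrices so this dispersion reaches `final_dispersion`. The same computation, standalone on synthetic SPD matrices (illustrative sketch, not part of this diff):

import numpy as np
from pyriemann.utils.mean import mean_riemann
from pyriemann.utils.distance import distance

rng = np.random.RandomState(0)
A = rng.randn(50, 3, 3)
X = A @ A.transpose(0, 2, 1) + 3 * np.eye(3)  # SPD by construction

M = mean_riemann(X)  # geometric (Riemannian) mean of the domain
# dispersion, as stored in dispersions_[domain] by the patched fit
disp = distance(X, M, metric='riemann', squared=True).mean()
print(disp)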
+
+
+class ReCenter(BaseEstimator, TransformerMixin):
+    """Re-center each dataset separately for transfer learning.
+
+    The data from each dataset are re-centered to the Identity using TLCenter
+    from pyRiemann. The difference is that we assume no target data are
+    available at training time, so when ``transform`` is called on target
+    (test) data, TLCenter is fitted and applied on those data directly.
+
+    Parameters
+    ----------
+    metric : str, default='riemann'
+        The metric to compute the mean.
+    """
+    def __init__(self, metric='riemann'):
+        self.metric = metric
+
+    def fit(self, X, y, domains=None):
+        """Fit ReCenter.
+
+        The mean of each domain is calculated with TLCenter from pyRiemann.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_matrices, n_channels, n_channels)
+            Set of SPD matrices.
+        y : ndarray, shape (n_matrices,)
+            Labels for each matrix.
+        domains : list of str | None
+            Domain of each matrix. Defaults to a single source domain.
+
+        Returns
+        -------
+        self : ReCenter instance
+        """
+        X = _check_data(X)
+        domains = _check_domains(domains, X.shape[0], target=False)
+        self._domains_source = domains
+        _, y_enc = encode_domains(X, y, domains)
+        self.re_center_ = TLCenter('target_domain', metric=self.metric)
+        self.means_ = self.re_center_.fit(X, y_enc).recenter_
+        return self
+
+    def transform(self, X, y=None, domains=None):
+        """Re-center the test data.
+
+        Calculate the mean and then transform the data. It is assumed that
+        the data points in X are all from the same domain and that this
+        domain was not present in the data used in ``fit``.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_matrices, n_channels, n_channels)
+            Set of SPD matrices.
+        y : ndarray, shape (n_matrices,)
+            Labels for each matrix.
+        domains : list of str | None
+            Domain of each matrix. Defaults to the target domain.
+
+        Returns
+        -------
+        X_rct : ndarray, shape (n_matrices, n_channels, n_channels)
+            Set of SPD matrices with mean at the Identity.
+        """
+        X = _check_data(X)
+        n_sample = X.shape[0]
+        domains = _check_domains(domains, n_sample, target=True)
+        _, y_enc = encode_domains(X, [0] * n_sample, domains)
+        if 'target_domain' in self._domains_source:
+            X_rct = self.re_center_.transform(X, y_enc)
+        else:
+            X_rct = self.re_center_.fit_transform(X, y_enc)
+        return X_rct
+
+    def fit_transform(self, X, y, domains=None):
+        """Fit ReCenter and transform the data.
+
+        Calculate the mean of each domain with TLCenter from pyRiemann and
+        then transform the data.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_matrices, n_channels, n_channels)
+            Set of SPD matrices.
+        y : ndarray, shape (n_matrices,)
+            Labels for each matrix.
+        domains : list of str | None
+            Domain of each matrix. Defaults to a single source domain.
+
+        Returns
+        -------
+        X_rct : ndarray, shape (n_matrices, n_channels, n_channels)
+            Set of SPD matrices with mean at the Identity.
+        """
+        self.fit(X, y, domains)
+        X = _check_data(X)
+        domains = _check_domains(domains, X.shape[0], target=False)
+        _, y_enc = encode_domains(X, y, domains)
+        X_rct = self.re_center_.fit_transform(X, y_enc)
+        return X_rct
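[review note] A standalone check of the `ReCenter` contract — fit on source-only data, then transform unseen target data, which takes the internal `fit_transform` branch described in the docstring (synthetic data, illustrative only):

import numpy as np
from pyriemann.utils.mean import mean_riemann
from coffeine.transfer_learning import ReCenter

rng = np.random.RandomState(1)
A = rng.randn(30, 3, 3)
X_src = A @ A.transpose(0, 2, 1) + 3 * np.eye(3)
B = rng.randn(10, 3, 3)
X_tgt = 2 * (B @ B.transpose(0, 2, 1) + 3 * np.eye(3))
y_src = rng.randint(0, 2, 30)

rct = ReCenter(metric='riemann')
X_src_rct = rct.fit_transform(X_src, y_src)  # domains default to source
X_tgt_rct = rct.transform(X_tgt)             # domains default to target
# both re-centered sets should have geometric mean close to the identity
print(mean_riemann(X_src_rct).round(2))
print(mean_riemann(X_tgt_rct).round(2))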
+
+
+class ReScale(BaseEstimator, TransformerMixin):
+    """Re-scale each dataset separately for transfer learning.
+
+    The data from each dataset are re-scaled to unit dispersion around their
+    mean using TLStretch from pyRiemann. The difference is that we assume no
+    target data are available at training time, so when ``transform`` is
+    called on target (test) data, TLStretch is fitted and applied on those
+    data directly. It is also assumed that the data were re-centered
+    beforehand.
+
+    Parameters
+    ----------
+    metric : str, default='riemann'
+        The metric to compute the dispersion.
+    """
+    def __init__(self, metric='riemann'):
+        self.metric = metric
+
+    def fit(self, X, y, domains):
+        """Fit ReScale.
+
+        Dispersions around the mean of each domain are calculated with
+        TLStretch from pyRiemann.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_matrices, n_channels, n_channels)
+            Set of SPD matrices.
+        y : ndarray, shape (n_matrices,)
+            Labels for each matrix.
+        domains : list of str
+            Domain of each matrix.
+
+        Returns
+        -------
+        self : ReScale instance
+        """
+        X = _check_data(X)
+        domains = _check_domains(domains, X.shape[0], target=False)
+        self._domains_source = domains
+        _, y_enc = encode_domains(X, y, domains)
+        # TLStretch.fit had a bug that was fixed in pyRiemann 0.6, hence
+        # the patched class for older versions
+        if pyriemann.__version__ < '0.6':
+            self.re_scale_ = TLStretch_patch(
+                'target_domain', centered_data=False, metric=self.metric
+            )
+        else:
+            self.re_scale_ = TLStretch(
+                'target_domain', centered_data=False, metric=self.metric
+            )
+        self.dispersions_ = self.re_scale_.fit(X, y_enc).dispersions_
+        return self
+
+    def transform(self, X, y=None, domains=None):
+        """Re-scale the test data.
+
+        Calculate the dispersion around the mean and then transform the
+        data. It is assumed that the data points in X are all from the same
+        domain and that this domain was not present in the data used in
+        ``fit``.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_matrices, n_channels, n_channels)
+            Set of SPD matrices.
+        y : ndarray, shape (n_matrices,)
+            Labels for each matrix.
+        domains : list of str | None
+            Domain of each matrix. Defaults to the target domain.
+
+        Returns
+        -------
+        X_str : ndarray, shape (n_matrices, n_channels, n_channels)
+            Set of SPD matrices with a dispersion equal to 1.
+        """
+        X = _check_data(X)
+        n_sample = X.shape[0]
+        domains = _check_domains(domains, n_sample, target=True)
+        _, y_enc = encode_domains(X, [0] * n_sample, domains)
+        if 'target_domain' in self._domains_source:
+            X_str = self.re_scale_.transform(X)
+        else:
+            X_str = self.re_scale_.fit_transform(X, y_enc)
+        return X_str
+
+    def fit_transform(self, X, y, domains):
+        """Fit ReScale and transform the data.
+
+        Calculate the dispersions around the mean of each domain with
+        TLStretch from pyRiemann and then transform the data.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_matrices, n_channels, n_channels)
+            Set of SPD matrices.
+        y : ndarray, shape (n_matrices,)
+            Labels for each matrix.
+        domains : list of str
+            Domain of each matrix.
+
+        Returns
+        -------
+        X_str : ndarray, shape (n_matrices, n_channels, n_channels)
+            Set of SPD matrices with a dispersion equal to 1.
+        """
+        self.fit(X, y, domains)
+        X = _check_data(X)
+        domains = _check_domains(domains, X.shape[0], target=False)
+        _, y_enc = encode_domains(X, y, domains)
+        X_str = self.re_scale_.fit_transform(X, y_enc)
+        return X_str
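[review note] `_check_data` is what lets these steps run inside the filter bank: within the column transformer, each band arrives as a one-column pandas object frame of 2D covariances and must be stacked back into the 3D array pyRiemann expects. A sketch of that conversion (illustrative, relying on the private helper defined above):

import numpy as np
import pandas as pd
from coffeine.transfer_learning import _check_data

covs = [k * np.eye(3) for k in range(1, 6)]
X_col = pd.DataFrame({'alpha': covs})  # dtype object, shape (5, 1)

out = _check_data(X_col)
print(out.shape)  # (5, 3, 3)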
diff --git a/doc/tutorials/filterbank_cross_subject_classification_bci.ipynb b/doc/tutorials/filterbank_cross_subject_classification_bci.ipynb
new file mode 100644
index 0000000..31ea930
--- /dev/null
+++ b/doc/tutorials/filterbank_cross_subject_classification_bci.ipynb
@@ -0,0 +1,337 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Building a filterbank model with alignment\n",
+    "\n",
+    "This notebook is based on the [MNE example](https://mne.tools/dev/auto_examples/decoding/decoding_csp_eeg.html) and illustrates the construction of filterbank models that include alignment steps. Here, we perform cross-subject classification.\n",
+    "\n",
+    "First we load the data of two subjects from the EEGBCI dataset: one for training, which we refer to as the source subject, and one for testing, the target subject."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "\n",
+    "from sklearn.linear_model import LogisticRegression\n",
+    "from sklearn.model_selection import ShuffleSplit, cross_val_score\n",
+    "\n",
+    "import mne\n",
+    "from mne import Epochs, pick_types, events_from_annotations\n",
+    "from mne.io import concatenate_raws, read_raw_edf\n",
+    "from mne.datasets import eegbci\n",
+    "\n",
+    "from coffeine import compute_coffeine, make_filter_bank_classifier"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mne.set_log_level('critical')\n",
+    "pd.set_option(\"large_repr\", \"info\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tmin, tmax = -1.0, 4.0\n",
+    "event_id = dict(hands=2, feet=3)\n",
+    "subject_source = 1\n",
+    "subject_target = 5\n",
+    "runs = [6, 10, 14]  # motor imagery: hands vs feet\n",
+    "raw_fnames_source = eegbci.load_data(subject_source, runs)\n",
+    "raw_fnames_target = eegbci.load_data(subject_target, runs)\n",
+    "raw_source = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames_source])\n",
+    "raw_target = concatenate_raws([read_raw_edf(f, preload=True) for f in raw_fnames_target])\n",
+    "eegbci.standardize(raw_source)  # set channel names\n",
+    "eegbci.standardize(raw_target)  # set channel names"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Apply band-pass filter\n",
+    "raw_source.filter(4.0, 35.0, fir_design=\"firwin\", skip_by_annotation=\"edge\")\n",
+    "raw_target.filter(4.0, 35.0, fir_design=\"firwin\", skip_by_annotation=\"edge\")\n",
+    "\n",
+    "events_source, _ = events_from_annotations(raw_source, event_id=dict(T1=2, T2=3))\n",
+    "events_target, _ = events_from_annotations(raw_target, event_id=dict(T1=2, T2=3))\n",
+    "picks_source = pick_types(raw_source.info, meg=False, eeg=True, stim=False, eog=False, exclude=\"bads\")\n",
+    "picks_target = pick_types(raw_target.info, meg=False, eeg=True, stim=False, eog=False, exclude=\"bads\")\n",
+    "\n",
+    "# Read epochs\n",
+    "epochs_source = Epochs(\n",
+    "    raw_source,\n",
+    "    events_source,\n",
+    "    event_id,\n",
+    "    tmin,\n",
+    "    tmax,\n",
+    "    proj=True,\n",
+    "    picks=picks_source,\n",
+    "    baseline=None,\n",
+    "    preload=True,\n",
+    ")\n",
+    "epochs_target = Epochs(\n",
+    "    raw_target,\n",
+    "    events_target,\n",
+    "    event_id,\n",
+    "    tmin,\n",
+    "    tmax,\n",
+    "    proj=True,\n",
+    "    picks=picks_target,\n",
+    "    baseline=None,\n",
+    "    preload=True,\n",
+    ")\n",
+    "\n",
+    "labels_source = epochs_source.events[:, -1] - 2\n",
+    "labels_target = epochs_target.events[:, -1] - 2\n",
+    "conditions = ['feet', 'hand']"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Building separate coffeine dataframes for source and target data\n",
+    "\n",
+    "Covariances are computed on pre-defined frequency bands for each subject and dataframes are created with columns corresponding to the frequency bands. The elements of the dataframes are the covariances."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<class 'pandas.core.frame.DataFrame'>\n",
+       "RangeIndex: 5 entries, 0 to 4\n",
+       "Data columns (total 2 columns):\n",
+       " #   Column  Non-Null Count  Dtype \n",
+       "---  ------  --------------  ----- \n",
+       " 0   alpha1  5 non-null      object\n",
+       " 1   alpha2  5 non-null      object\n",
+       "dtypes: object(2)\n",
+       "memory usage: 212.0+ bytes\n",
+       ""
+      ],
+      "text/plain": [
+       "<class 'pandas.core.frame.DataFrame'>\n",
+       "RangeIndex: 5 entries, 0 to 4\n",
+       "Data columns (total 2 columns):\n",
+       " #   Column  Non-Null Count  Dtype \n",
+       "---  ------  --------------  ----- \n",
+       " 0   alpha1  5 non-null      object\n",
+       " 1   alpha2  5 non-null      object\n",
+       "dtypes: object(2)\n",
+       "memory usage: 212.0+ bytes"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "X_df_source, feature_info_source = compute_coffeine(epochs_source, frequencies=('ipeg', ['alpha1', 'alpha2']))\n",
+    "X_df_target, feature_info_target = compute_coffeine(epochs_target, frequencies=('ipeg', ['alpha1', 'alpha2']))\n",
+    "X_df_source.head()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Comparing classification accuracy with and without alignment"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We first construct a model without alignment steps."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fb_model = make_filter_bank_classifier(\n",
+    "    names=list(X_df_source.columns),\n",
+    "    method='riemann',\n",
+    "    projection_params=dict(scale=1, n_compo=60, reg=0),\n",
+    "    estimator=LogisticRegression(solver='liblinear', C=1e7)\n",
+    ")\n",
+    "fb_model.fit(X_df_source, labels_source)\n",
+    "score = fb_model.score(X_df_target, labels_target)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0.5333333333333333"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "score"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fb_model = make_filter_bank_classifier(\n",
+    "    names=list(X_df_source.columns),\n",
+    "    method='riemann',\n",
+    "    alignment=['re-center', 're-scale'],\n",
+    "    projection_params=dict(scale=1, n_compo=60, reg=0),\n",
+    "    estimator=LogisticRegression(solver='liblinear', C=1e7)\n",
+    ")\n",
+    "fb_model.fit(X_df_source, labels_source, domains=['source']*X_df_source.shape[0])\n",
+    "score = fb_model.score(X_df_target, labels_target, domains=['target_domain']*X_df_target.shape[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "0.6222222222222222"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "score"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cv = ShuffleSplit(10, test_size=0.8, random_state=42)\n",
+    "scores = []\n",
+    "for train_index, test_index in cv.split(X_df_target):\n",
+    "    X_df_target_train = X_df_target.iloc[train_index]\n",
+    "    labels_target_train = labels_target[train_index]\n",
+    "    X_df_target_test = X_df_target.iloc[test_index]\n",
+    "    labels_target_test = labels_target[test_index]\n",
+    "    X_df_train = pd.concat([X_df_source, X_df_target_train])\n",
+    "    y_train = np.concatenate([labels_source, labels_target_train])\n",
+    "    domains = ['source']*X_df_source.shape[0] + ['target_domain']*X_df_target_train.shape[0]\n",
+    "    fb_model.fit(X_df_train, y_train, domains=domains)\n",
+    "    scores.append(fb_model.score(X_df_target_test, labels_target_test,\n",
+    "                  domains=['target_domain']*X_df_target_test.shape[0]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Mean classification accuracy: 0.61\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(f'Mean classification accuracy: {np.mean(scores):0.2f}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[0.75,\n",
+       " 0.5833333333333334,\n",
+       " 0.6388888888888888,\n",
+       " 0.6666666666666666,\n",
+       " 0.4722222222222222,\n",
+       " 0.5833333333333334,\n",
+       " 0.6388888888888888,\n",
+       " 0.6111111111111112,\n",
+       " 0.5,\n",
+       " 0.6944444444444444]"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "scores"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "dameeg",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.1"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/requirements.txt b/requirements.txt
index 44fd1b4..307f3ba 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,6 @@ numpy>=1.18.1
 scipy>=1.4.1
 matplotlib>=2.0.0
 pandas>=1.0.0
-pyriemann>=0.4
+pyriemann>=0.5
 scikit-learn>=1.0
 mne[data]>=1.0
\ No newline at end of file