Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Include alignment steps in coffeine pipeline #58

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions coffeine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@
from .power_features import compute_features, get_frequency_bands, compute_coffeine, make_coffeine_data_frame # noqa

from .spatial_filters import ProjIdentitySpace, ProjCommonSpace, ProjLWSpace, ProjRandomSpace, ProjSPoCSpace # noqa

from .transfer_learning import ReCenter, ReScale # noqa
42 changes: 36 additions & 6 deletions coffeine/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
ProjRandomSpace,
ProjSPoCSpace)

from coffeine.transfer_learning import (
ReCenter,
ReScale
)

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.compose import make_column_transformer
from sklearn.pipeline import make_pipeline, Pipeline
Expand Down Expand Up @@ -142,6 +147,8 @@ def transform(self,
def make_filter_bank_transformer(
names: list[str],
method: str = 'riemann',
alignment: Union[list[str], None] = None,
domains: Union[list[str], None] = None,
projection_params: Union[dict, None] = None,
vectorization_params: Union[dict, None] = None,
kernel: Union[str, Pipeline, None] = None,
Expand Down Expand Up @@ -184,6 +191,9 @@ def make_filter_bank_transformer(
to ``'riemann'``. Can be ``'riemann'``, ``'lw_riemann'``, ``'diag'``,
``'log_diag'``, ``'random'``, ``'naive'``, ``'spoc'``,
``'riemann_wasserstein'``.
alignment : list of str | None
Alignment steps to include in the pipeline. Can be ``'re-center'``,
``'re-scale'``.
projection_params : dict | None
The parameters for the projection step.
vectorization_params : dict | None
Expand Down Expand Up @@ -249,13 +259,22 @@ def make_filter_bank_transformer(
if vectorization_params is not None:
vectorization_params_.update(**vectorization_params)

def _get_projector_vectorizer(projection, vectorization, kernel=None):
def _get_projector_vectorizer(projection, vectorization,
alignment_steps=None,
kernel=None):
out = list()
for name in names:
steps = [
projection(**projection_params_),
vectorization(**vectorization_params_)
]
if alignment_steps is None:
steps = [
projection(**projection_params_),
vectorization(**vectorization_params_)
]
else:
steps = [
projection(**projection_params_)
] + alignment_steps + [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we be sure that this not buggy? We need to make sure that for every column (‘name’) independent instances are instantiated. Otherwise you risk some stateful behavior.

vectorization(**vectorization_params_)
]
if kernel is not None:
kernel_name, kernel_estimator = kernel
steps.append(kernel_estimator())
Expand All @@ -282,6 +301,16 @@ def _get_projector_vectorizer(projection, vectorization, kernel=None):
elif method == 'riemann_wasserstein':
steps = (ProjIdentitySpace, RiemannSnp)

# add alignment options
alignment_steps = []
if alignment is None:
alignment_steps = None
else:
if 're-center' in alignment:
alignment_steps.append(ReCenter(domains=domains))
if 're-scale' in alignment:
alignment_steps.append(ReScale(domains=domains))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comment above. We would use the same ReCenter/ReScale instances for every column of the data frame based on this code here, right?

Copy link
Collaborator Author

@apmellot apmellot Aug 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand, isn't it the same thing as what is done a few lines above for the projection and vectorization steps?
Also in practice I checked and each frequency band is aligned independently.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It’s not the same, as we instantiate the classes independently for every column. We can check, of course, whether this is practically unnecessary by now because of internal cloning etc. But if you want to be consistent with the tested / safe code, you can just adopt the same pattern as for the other projection/vectorization steps. You currently instantiate the classes globally and plug the same instance in column by column.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok I see, let's do it as projection/vectorization.


# add Kernel options
if (isinstance(kernel, Pipeline) and not
isinstance(kernel, (BaseEstimator, TransformerMixin))):
Expand All @@ -295,7 +324,8 @@ def _get_projector_vectorizer(projection, vectorization, kernel=None):
combine_kernels = 'sum'

filter_bank_transformer = make_column_transformer(
*_get_projector_vectorizer(*steps, kernel=kernel),
*_get_projector_vectorizer(*steps, alignment_steps=alignment_steps,
kernel=kernel),
remainder='passthrough'
)

Expand Down
200 changes: 200 additions & 0 deletions coffeine/transfer_learning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import numpy as np
from pyriemann.transfer import TLCenter, TLStretch, encode_domains
from sklearn.base import BaseEstimator, TransformerMixin


def _check_data(X):
    """Coerce input into a proper 3d array of covariance matrices.

    Parameters
    ----------
    X : ndarray, shape (n_matrices, n_channels, n_channels) | pandas object
        Either an already-stacked 3d array of SPD matrices, or a pandas
        container whose object-dtype cells each hold one
        (n_channels, n_channels) matrix.

    Returns
    -------
    out : ndarray, shape (n_matrices, n_channels, n_channels)
        The stacked 3d array of covariance matrices.

    Raises
    ------
    ValueError
        If ``X`` is neither a 3d array nor an object-dtype container.
    """
    if X.ndim == 3:
        return X
    values = getattr(X, 'values', None)
    if values is not None and values.dtype == 'object':
        # first remove unnecessary dimensions, then stack to 3d data
        # (guard on ndim: a Series' .values is 1d and has no shape[1])
        if values.ndim == 2 and values.shape[1] == 1:
            values = values[:, 0]
        out = np.stack(values)
        if out.ndim == 2:  # deal with single sample
            assert out.shape[0] == out.shape[1]
            out = out[np.newaxis, :, :]
        return out
    # Previously this fell through and returned None, producing opaque
    # downstream errors; fail loudly instead.
    raise ValueError(
        'Expected a 3d array of covariances or an object-dtype container '
        f'of matrices, got ndim={X.ndim!r}.'
    )


class ReCenter(BaseEstimator, TransformerMixin):
    """Re-center each dataset separately for transfer learning.

    The data from each dataset are re-centered to the Identity using TLCenter
    from pyRiemann. The difference is we assume to not have target data in the
    training sample, so when the transform function is called we fit_transform
    TLCenter on the target (test) data.

    Parameters
    ----------
    domains : list | ndarray, shape (n_matrices,)
        Domain label of each training matrix, forwarded to
        ``pyriemann.transfer.encode_domains``.
    metric : str, default='riemann'
        The metric to compute the mean.
    """
    def __init__(self, domains, metric='riemann'):
        self.domains = domains
        self.metric = metric

    def fit(self, X, y):
        """Fit ReCenter.

        Mean of each domain are calculated with TLCenter from
        pyRiemann.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        self : ReCenter instance
        """
        X = _check_data(X)
        _, y_enc = encode_domains(X, y, self.domains)
        self.re_center_ = TLCenter('target_domain', metric=self.metric)
        self.means_ = self.re_center_.fit(X, y_enc).recenter_
        return self

    def transform(self, X, y=None):
        """Re-center the test data.

        Calculate the mean and then transform the data.
        It is assumed that data points in X are all from the same domain
        and that this domain is not present in the data used in fit.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        X_rct : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices with mean at the Identity.
        """
        X = _check_data(X)
        n_sample = X.shape[0]
        # All test samples are treated as one unseen 'target_domain';
        # NOTE(review): this re-fits on the test data by design (see class
        # docstring), so transform is stateful across calls.
        _, y_enc = encode_domains(X, [0]*n_sample, ['target_domain']*n_sample)
        X_rct = self.re_center_.fit_transform(X, y_enc)
        return X_rct

    def fit_transform(self, X, y):
        """Fit ReCenter and transform the data.

        Calculate the mean of each domain with TLCenter from pyRiemann and
        then transform the data.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        X_rct : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices with mean at the Identity.
        """
        X = _check_data(X)
        _, y_enc = encode_domains(X, y, self.domains)
        self.re_center_ = TLCenter('target_domain', metric=self.metric)
        X_rct = self.re_center_.fit_transform(X, y_enc)
        # Keep fitted state consistent with fit(): expose the learned means.
        self.means_ = self.re_center_.recenter_
        return X_rct


class ReScale(BaseEstimator, TransformerMixin):
    """Re-scale each dataset separately for transfer learning.

    The data from each dataset are re-scaled to the Identity using TLStretch
    from pyRiemann. The difference is we assume to not have target data in the
    training sample, so when the transform function is called we fit_transform
    TLStretch on the target (test) data. It is also assumed that the data were
    re-centered beforehand.

    Parameters
    ----------
    domains : list | ndarray, shape (n_matrices,)
        Domain label of each training matrix, forwarded to
        ``pyriemann.transfer.encode_domains``.
    metric : str, default='riemann'
        The metric to compute the dispersion.
    """
    def __init__(self, domains, metric='riemann'):
        self.domains = domains
        self.metric = metric

    def fit(self, X, y):
        """Fit ReScale.

        Dispersions around the mean of each domain are calculated with
        TLStretch from pyRiemann.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        self : ReScale instance
        """
        X = _check_data(X)
        _, y_enc = encode_domains(X, y, self.domains)
        # centered_data=True: assumes ReCenter (or equivalent) ran before.
        self.re_scale_ = TLStretch('target_domain',
                                   centered_data=True,
                                   metric=self.metric)
        self.dispersions_ = self.re_scale_.fit(X, y_enc).dispersions_
        return self

    def transform(self, X, y=None):
        """Re-scale the test data.

        Calculate the dispersion around the mean and then transform the data.
        It is assumed that data points in X are all from the same domain
        and that this domain is not present in the data used in fit.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        X_str : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices with a dispersion equal to 1.
        """
        X = _check_data(X)
        n_sample = X.shape[0]
        # All test samples are treated as one unseen 'target_domain';
        # NOTE(review): this re-fits on the test data by design (see class
        # docstring), so transform is stateful across calls.
        _, y_enc = encode_domains(X, [0]*n_sample, ['target_domain']*n_sample)
        X_str = self.re_scale_.fit_transform(X, y_enc)
        return X_str

    def fit_transform(self, X, y):
        """Fit ReScale and transform the data.

        Calculate the dispersions around the mean of each domain with
        TLStretch from pyRiemann and then transform the data.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        X_str : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices with a dispersion equal to 1.
        """
        X = _check_data(X)
        _, y_enc = encode_domains(X, y, self.domains)
        self.re_scale_ = TLStretch('target_domain',
                                   centered_data=True,
                                   metric=self.metric)
        X_str = self.re_scale_.fit_transform(X, y_enc)
        # Keep fitted state consistent with fit(): expose the dispersions.
        self.dispersions_ = self.re_scale_.dispersions_
        return X_str