-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Include alignment steps in coffeine pipeline #58
base: main
Are you sure you want to change the base?
Changes from 4 commits
14405aa
35f7e28
19a6c68
f5df3a1
c7c6f48
9eb3920
c95d96f
de35d1f
f2d5809
72f8967
50f5531
e933106
c08504f
9c0491d
1cea3fb
ef0830c
56d10e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,11 @@ | |
ProjRandomSpace, | ||
ProjSPoCSpace) | ||
|
||
from coffeine.transfer_learning import ( | ||
ReCenter, | ||
ReScale | ||
) | ||
|
||
from sklearn.base import BaseEstimator, TransformerMixin | ||
from sklearn.compose import make_column_transformer | ||
from sklearn.pipeline import make_pipeline, Pipeline | ||
|
@@ -142,6 +147,8 @@ def transform(self, | |
def make_filter_bank_transformer( | ||
names: list[str], | ||
method: str = 'riemann', | ||
alignment: Union[list[str], None] = None, | ||
domains: Union[list[str], None] = None, | ||
projection_params: Union[dict, None] = None, | ||
vectorization_params: Union[dict, None] = None, | ||
kernel: Union[str, Pipeline, None] = None, | ||
|
@@ -184,6 +191,9 @@ def make_filter_bank_transformer( | |
to ``'riemann'``. Can be ``'riemann'``, ``'lw_riemann'``, ``'diag'``, | ||
``'log_diag'``, ``'random'``, ``'naive'``, ``'spoc'``, | ||
``'riemann_wasserstein'``. | ||
alignment : list of str | None | ||
Alignment steps to include in the pipeline. Can be ``'re-center'``, | ||
``'re-scale'``. | ||
projection_params : dict | None | ||
The parameters for the projection step. | ||
vectorization_params : dict | None | ||
|
@@ -249,13 +259,22 @@ def make_filter_bank_transformer( | |
if vectorization_params is not None: | ||
vectorization_params_.update(**vectorization_params) | ||
|
||
def _get_projector_vectorizer(projection, vectorization, kernel=None): | ||
def _get_projector_vectorizer(projection, vectorization, | ||
alignment_steps=None, | ||
kernel=None): | ||
out = list() | ||
for name in names: | ||
steps = [ | ||
projection(**projection_params_), | ||
vectorization(**vectorization_params_) | ||
] | ||
if alignment_steps is None: | ||
steps = [ | ||
projection(**projection_params_), | ||
vectorization(**vectorization_params_) | ||
] | ||
else: | ||
steps = [ | ||
projection(**projection_params_) | ||
] + alignment_steps + [ | ||
vectorization(**vectorization_params_) | ||
] | ||
if kernel is not None: | ||
kernel_name, kernel_estimator = kernel | ||
steps.append(kernel_estimator()) | ||
|
@@ -282,6 +301,16 @@ def _get_projector_vectorizer(projection, vectorization, kernel=None): | |
elif method == 'riemann_wasserstein': | ||
steps = (ProjIdentitySpace, RiemannSnp) | ||
|
||
# add alignment options | ||
alignment_steps = [] | ||
if alignment is None: | ||
alignment_steps = None | ||
else: | ||
if 're-center' in alignment: | ||
alignment_steps.append(ReCenter(domains=domains)) | ||
if 're-scale' in alignment: | ||
alignment_steps.append(ReScale(domains=domains)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See comment above. We would use the same ReCenter/ReScale instances for every column of the data frame based on this code here, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand, isn't it the same thing as what is done a few lines above for the projection and vectorization steps? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It’s not the same as we instantiate the classes independently for every column. We can check of course if this is practically unneccessary by now because of internal cloning etc. But if you want to be consistent with the tested / save code you can just adopt the same pattern as for the other projection/vectorization steps. You currently instantiate the classes globally and plug in the same input column by column. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok I see, let's do it as projection/vectorization. |
||
|
||
# add Kernel options | ||
if (isinstance(kernel, Pipeline) and not | ||
isinstance(kernel, (BaseEstimator, TransformerMixin))): | ||
|
@@ -295,7 +324,8 @@ def _get_projector_vectorizer(projection, vectorization, kernel=None): | |
combine_kernels = 'sum' | ||
|
||
filter_bank_transformer = make_column_transformer( | ||
*_get_projector_vectorizer(*steps, kernel=kernel), | ||
*_get_projector_vectorizer(*steps, alignment_steps=alignment_steps, | ||
kernel=kernel), | ||
remainder='passthrough' | ||
) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
import numpy as np | ||
from pyriemann.transfer import TLCenter, TLStretch, encode_domains | ||
from sklearn.base import BaseEstimator, TransformerMixin | ||
|
||
|
||
def _check_data(X): | ||
# make proper 3d array of covariances | ||
out = None | ||
if X.ndim == 3: | ||
out = X | ||
elif X.values.dtype == 'object': | ||
# first remove unnecessary dimensions, | ||
# then stack to 3d data | ||
values = X.values | ||
if values.shape[1] == 1: | ||
values = values[:, 0] | ||
out = np.stack(values) | ||
if out.ndim == 2: # deal with single sample | ||
assert out.shape[0] == out.shape[1] | ||
out = out[np.newaxis, :, :] | ||
return out | ||
|
||
|
||
class ReCenter(BaseEstimator, TransformerMixin):
    """Re-center each dataset separately for transfer learning.

    The data from each dataset are re-centered to the Identity using TLCenter
    from pyRiemann. The difference is we assume to not have target data in the
    training sample, so when the ``transform`` method is called we
    fit-transform TLCenter on the target (test) data instead of re-using the
    means estimated during ``fit``.

    Parameters
    ----------
    domains : array-like of str
        Domain label of each training matrix, as expected by
        ``pyriemann.transfer.encode_domains``.
    metric : str, default='riemann'
        The metric to compute the mean.
    """
    def __init__(self, domains, metric='riemann'):
        self.domains = domains
        self.metric = metric

    def fit(self, X, y):
        """Fit ReCenter.

        Means of each domain are calculated with TLCenter from
        pyRiemann.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        self : ReCenter instance
        """
        X = _check_data(X)
        _, y_enc = encode_domains(X, y, self.domains)
        self.re_center_ = TLCenter('target_domain', metric=self.metric)
        self.means_ = self.re_center_.fit(X, y_enc).recenter_
        return self

    def transform(self, X, y=None):
        """Re-center the test data.

        Calculate the mean and then transform the data.
        It is assumed that data points in X are all from the same domain
        and that this domain is not present in the data used in fit,
        hence TLCenter is re-fitted on X here rather than re-using the
        training means.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        X_rct : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices with mean at the Identity.
        """
        X = _check_data(X)
        n_sample = X.shape[0]
        # every test sample is treated as belonging to the target domain
        _, y_enc = encode_domains(X, [0]*n_sample, ['target_domain']*n_sample)
        X_rct = self.re_center_.fit_transform(X, y_enc)
        return X_rct

    def fit_transform(self, X, y):
        """Fit ReCenter and transform the data.

        Calculate the mean of each domain with TLCenter from pyRiemann and
        then transform the data. Overridden because ``transform`` re-fits
        on target data, which is not the desired behavior at train time.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        X_rct : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices with mean at the Identity.
        """
        X = _check_data(X)
        _, y_enc = encode_domains(X, y, self.domains)
        self.re_center_ = TLCenter('target_domain', metric=self.metric)
        X_rct = self.re_center_.fit_transform(X, y_enc)
        return X_rct
|
||
|
||
class ReScale(BaseEstimator, TransformerMixin):
    """Re-scale each dataset separately for transfer learning.

    The data from each dataset are re-scaled to the Identity using TLStretch
    from pyRiemann. The difference is we assume to not have target data in the
    training sample, so when the ``transform`` method is called we
    fit-transform TLStretch on the target (test) data. It is also assumed
    that the data were re-centered beforehand.

    Parameters
    ----------
    domains : array-like of str
        Domain label of each training matrix, as expected by
        ``pyriemann.transfer.encode_domains``.
    metric : str, default='riemann'
        The metric to compute the dispersion.
    """
    def __init__(self, domains, metric='riemann'):
        self.domains = domains
        self.metric = metric

    def fit(self, X, y):
        """Fit ReScale.

        Dispersions around the mean of each domain are calculated with
        TLStretch from pyRiemann.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        self : ReScale instance
        """
        X = _check_data(X)
        _, y_enc = encode_domains(X, y, self.domains)
        # centered_data=True: assumes ReCenter (or equivalent) ran before
        self.re_scale_ = TLStretch('target_domain',
                                   centered_data=True,
                                   metric=self.metric)
        self.dispersions_ = self.re_scale_.fit(X, y_enc).dispersions_
        return self

    def transform(self, X, y=None):
        """Re-scale the test data.

        Calculate the dispersion around the mean and then transform the data.
        It is assumed that data points in X are all from the same domain
        and that this domain is not present in the data used in fit,
        hence TLStretch is re-fitted on X here rather than re-using the
        training dispersions.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        X_str : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices with a dispersion equal to 1.
        """
        X = _check_data(X)
        n_sample = X.shape[0]
        # every test sample is treated as belonging to the target domain
        _, y_enc = encode_domains(X, [0]*n_sample, ['target_domain']*n_sample)
        X_str = self.re_scale_.fit_transform(X, y_enc)
        return X_str

    def fit_transform(self, X, y):
        """Fit ReScale and transform the data.

        Calculate the dispersions around the mean of each domain with
        TLStretch from pyRiemann and then transform the data. Overridden
        because ``transform`` re-fits on target data, which is not the
        desired behavior at train time.

        Parameters
        ----------
        X : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices.
        y : ndarray, shape (n_matrices,)
            Labels for each matrix.

        Returns
        -------
        X_str : ndarray, shape (n_matrices, n_channels, n_channels)
            Set of SPD matrices with a dispersion equal to 1.
        """
        X = _check_data(X)
        _, y_enc = encode_domains(X, y, self.domains)
        self.re_scale_ = TLStretch('target_domain',
                                   centered_data=True,
                                   metric=self.metric)
        X_str = self.re_scale_.fit_transform(X, y_enc)
        return X_str
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we be sure that this is not buggy? We need to make sure that for every column (‘name’) independent instances are instantiated. Otherwise you risk some stateful behavior.