-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Auto-label/order microstates maps (#105)
* Add optimize_order * Fix style * Fix style * fix tests * Fix style * Apply suggestions from code review Co-authored-by: Mathieu Scheltienne <[email protected]> * Add template option to base.reorder * Fix docstring * Fix docstrings * Update .gitignore * Update pycrostates/cluster/utils/utils.py * Update pycrostates/cluster/_base.py * Update pycrostates/cluster/_base.py * Update pycrostates/cluster/_base.py * add type-hint for Cluster class * fix typo --------- Co-authored-by: Mathieu Scheltienne <[email protected]> Co-authored-by: Mathieu Scheltienne <[email protected]>
- Loading branch information
1 parent
3dd164e
commit 91df686
Showing
12 changed files
with
240 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
"""This module contains the clustering utils functions of ``pycrostates``.""" | ||
|
||
from .utils import optimize_order # noqa: F401 | ||
|
||
__all__ = ("optimize_order",) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""Test cluster utils.""" | ||
|
||
import numpy as np | ||
from mne.datasets import testing | ||
from mne.io import read_raw_fif | ||
|
||
from pycrostates.cluster import ModKMeans | ||
from pycrostates.cluster.utils.utils import _optimize_order, optimize_order | ||
|
||
directory = testing.data_path() / "MEG" / "sample" | ||
fname = directory / "sample_audvis_trunc_raw.fif" | ||
raw = read_raw_fif(fname, preload=False) | ||
raw.pick("eeg").crop(0, 10) | ||
raw.load_data() | ||
# Fit one for general purposes | ||
n_clusters = 5 | ||
ModK_0 = ModKMeans( | ||
n_clusters=n_clusters, n_init=10, max_iter=100, tol=1e-4, random_state=0 | ||
) | ||
ModK_0.fit(raw, n_jobs=1) | ||
|
||
ModK_1 = ModKMeans( | ||
n_clusters=n_clusters, n_init=10, max_iter=100, tol=1e-4, random_state=1 | ||
) | ||
ModK_1.fit(raw, n_jobs=1) | ||
|
||
|
||
def test__optimize_order(): | ||
template = ModK_1._cluster_centers_ | ||
n_states, n_electrodes = template.shape | ||
# Shuffle template | ||
arr = np.arange(n_states) | ||
np.random.shuffle(arr) | ||
random_template = template[arr] | ||
# invert polarity | ||
polarities = np.array([-1, 1, -1, 1, 1]) | ||
random_pol_template = polarities[:, np.newaxis] * random_template | ||
|
||
# No suffle | ||
current = template | ||
ignore_polarity = True | ||
order = _optimize_order(current, template, ignore_polarity=ignore_polarity) | ||
assert np.all(order == np.arange(n_states)) | ||
|
||
# Shuffle | ||
current = random_template | ||
ignore_polarity = False | ||
order = _optimize_order(current, template, ignore_polarity=ignore_polarity) | ||
assert np.allclose(current[order], template) | ||
|
||
# Shuffle + ignore_polarity | ||
current = random_template | ||
ignore_polarity = True | ||
order = _optimize_order(current, template, ignore_polarity=ignore_polarity) | ||
assert np.allclose(current[order], template) | ||
|
||
# Shuffle + sign + ignore_polarity | ||
current = random_pol_template | ||
ignore_polarity = True | ||
order_ = _optimize_order( | ||
current, template, ignore_polarity=ignore_polarity | ||
) | ||
assert np.all(order == order_) | ||
|
||
# Shuffle + sign | ||
current = random_pol_template | ||
ignore_polarity = False | ||
order = _optimize_order(current, template, ignore_polarity=ignore_polarity) | ||
corr = np.corrcoef(template, current[order])[n_states:, :n_states] | ||
corr_order = np.corrcoef(template, current[order])[n_states:, :n_states] | ||
assert np.trace(corr) <= np.trace(corr_order) | ||
|
||
|
||
def test_optimize_order(): | ||
order = optimize_order(ModK_0, ModK_1) | ||
assert np.all(np.sort(np.unique(order)) == np.arange(len(order))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import numpy as np | ||
import scipy | ||
from numpy.typing import NDArray | ||
|
||
from ..._typing import Cluster | ||
from ...utils._checks import _check_type | ||
from ...utils._docs import fill_doc | ||
|
||
|
||
def _optimize_order( | ||
centers: NDArray[float], | ||
template_centers: NDArray[float], | ||
ignore_polarity: bool = True, | ||
): | ||
n_states = len(centers) | ||
M = np.corrcoef(template_centers, centers)[:n_states, n_states:] | ||
if ignore_polarity: | ||
M = np.abs(M) | ||
_, order = scipy.optimize.linear_sum_assignment(-M) | ||
return order | ||
|
||
|
||
@fill_doc | ||
def optimize_order(inst: Cluster, template_inst: Cluster): | ||
"""Optimize the order of cluster centers between two cluster instances. | ||
Optimize the order of cluster centers in an instance of a clustering | ||
algorithm to maximize auto-correlation, based on a template instance | ||
as determined by the Hungarian algorithm. | ||
The two cluster instances must have the same number of cluster centers | ||
and the same polarity setting. | ||
Parameters | ||
---------- | ||
inst : :ref:`cluster` | ||
Fitted clustering algorithm to reorder. | ||
template_inst : :ref:`cluster` | ||
Fitted clustering algorithm to use as template for reordering. | ||
Returns | ||
------- | ||
order : list of int | ||
The new order to apply to inst to maximize auto-correlation | ||
of cluster centers. | ||
""" | ||
from .._base import _BaseCluster | ||
|
||
_check_type(inst, (_BaseCluster,), item_name="inst") | ||
inst._check_fit() | ||
_check_type(template_inst, (_BaseCluster,), item_name="template_inst") | ||
template_inst._check_fit() | ||
|
||
if inst.n_clusters != template_inst.n_clusters: | ||
raise ValueError( | ||
"Instance and the template must have the same " | ||
"number of cluster centers." | ||
) | ||
if inst._ignore_polarity != template_inst._ignore_polarity: | ||
raise ValueError( | ||
"Cannot find order: Instance was fitted using " | ||
f"ignore_polarity={inst._ignore_polarity} while " | ||
"template was fitted using ignore_polarity=" | ||
f"{inst._ignore_polarity} which could lead to " | ||
"misinterpretations." | ||
) | ||
inst_centers = inst._cluster_centers_ | ||
template_centers = template_inst._cluster_centers_ | ||
order = _optimize_order( | ||
inst_centers, template_centers, ignore_polarity=inst._ignore_polarity | ||
) | ||
return order |