From 91df686183c5db4d37814e90e5bbd84172804f2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Victor=20F=C3=A9rat?= Date: Thu, 13 Apr 2023 17:15:45 +0200 Subject: [PATCH] Auto-label/order microstates maps (#105) * Add optimize_order * Fix style * Fix style * fix tests * Fix style * Apply suggestions from code review Co-authored-by: Mathieu Scheltienne * Add template option to base.reorder * Fix docstring * Fix docstrings * Update .gitignore * Update pycrostates/cluster/utils/utils.py * Update pycrostates/cluster/_base.py * Update pycrostates/cluster/_base.py * Update pycrostates/cluster/_base.py * add type-hint for Cluster class * fix typo --------- Co-authored-by: Mathieu Scheltienne Co-authored-by: Mathieu Scheltienne --- .gitignore | 1 + docs/source/api/cluster.rst | 14 ++++ pycrostates/_typing.py | 6 ++ pycrostates/cluster/__init__.py | 6 +- pycrostates/cluster/_base.py | 43 +++++++++-- pycrostates/cluster/aahc.py | 2 +- pycrostates/cluster/kmeans.py | 1 + pycrostates/cluster/tests/test_aahc.py | 10 ++- pycrostates/cluster/tests/test_kmeans.py | 21 ++++- pycrostates/cluster/utils/__init__.py | 5 ++ pycrostates/cluster/utils/tests/test_utils.py | 76 +++++++++++++++++++ pycrostates/cluster/utils/utils.py | 71 +++++++++++++++++ 12 files changed, 240 insertions(+), 16 deletions(-) create mode 100644 pycrostates/cluster/utils/__init__.py create mode 100644 pycrostates/cluster/utils/tests/test_utils.py create mode 100644 pycrostates/cluster/utils/utils.py diff --git a/.gitignore b/.gitignore index d1f0f45f..509e9877 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,4 @@ dmypy.json # ---------------------------------------------------------------------------- .DS_Store +junit-results.xml diff --git a/docs/source/api/cluster.rst b/docs/source/api/cluster.rst index 17dc1660..bdf13202 100644 --- a/docs/source/api/cluster.rst +++ b/docs/source/api/cluster.rst @@ -14,3 +14,17 @@ Cluster ModKMeans AAHCluster + +Utils +----- + +.. currentmodule:: pycrostates.cluster.utils + +.. automodule:: pycrostates.cluster.utils + :no-members: + :no-inherited-members: + +.. autosummary:: + :toctree: generated/ + + optimize_order diff --git a/pycrostates/_typing.py b/pycrostates/_typing.py index 9dd3e9ac..79786b17 100644 --- a/pycrostates/_typing.py +++ b/pycrostates/_typing.py @@ -23,5 +23,11 @@ class CHInfo(ABC): pass +class Cluster(ABC): + """Typing for a clustering class.""" + + pass + + RANDomState = Optional[Union[int, RandomState, Generator]] Picks = Optional[Union[str, NDArray[int]]] diff --git a/pycrostates/cluster/__init__.py b/pycrostates/cluster/__init__.py index 712d5dc6..bd97d04f 100644 --- a/pycrostates/cluster/__init__.py +++ b/pycrostates/cluster/__init__.py @@ -14,7 +14,11 @@ :class:`~pycrostates.segmentation.EpochsSegmentation` depending on the dataset to segment.""" +from . import utils # noqa: F401 from .aahc import AAHCluster # noqa: F401 from .kmeans import ModKMeans # noqa: F401 -__all__ = ("ModKMeans", "AAHCluster") +__all__ = ( + "ModKMeans", + "AAHCluster", +) diff --git a/pycrostates/cluster/_base.py b/pycrostates/cluster/_base.py index bb2f3c37..4f440ce0 100644 --- a/pycrostates/cluster/_base.py +++ b/pycrostates/cluster/_base.py @@ -1,4 +1,4 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from copy import copy, deepcopy from itertools import groupby from pathlib import Path @@ -13,7 +13,7 @@ from numpy.typing import NDArray from scipy.signal import convolve2d -from .._typing import CHData, Picks +from .._typing import CHData, Cluster, Picks from ..segmentation import EpochsSegmentation, RawSegmentation from ..utils import _corr_vectors from ..utils._checks import ( @@ -27,9 +27,10 @@ from ..utils._logs import logger, verbose from ..utils.mixin import ChannelsMixin, ContainsMixin, MontageMixin from ..viz import plot_cluster_centers +from .utils import optimize_order -class _BaseCluster(ABC, ChannelsMixin, ContainsMixin, MontageMixin): +class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin): """Base Class for Microstates Clustering algorithms.""" @abstractmethod @@ -37,6 +38,7 @@ def __init__(self): self._n_clusters = None self._cluster_names = None self._cluster_centers_ = None + self._ignore_polarity = None # fit variables self._info = None @@ -363,8 +365,20 @@ def reorder_clusters( NDArray[int], ] ] = None, + template: Optional[Cluster] = None, ) -> None: - """Reorder the clusters. + """ + Reorder the clusters of the fitted model. + + Specify one of the following arguments to change the current order: + + * ``mapping``: a dictionary that maps old cluster positions + to new positions, + * ``order``: a 1D iterable containing the new order, + * ``template``: a fitted clustering algorithm used as a reference + to match the order. + + Only one argument can be set at a time. Parameters ---------- @@ -373,19 +387,27 @@ def reorder_clusters( key: old position, value: new position. order : list of int | tuple of int | array of int 1D iterable containing the new order. + Positions are 0-indexed. + template : :ref:`cluster` + Fitted clustering algorithm use as template for + ordering optimization. For more details about the + current implementation, check the + :func:`pycrostates.cluster.utils.optimize_order` + documentation. Notes ----- - The positions are 0-indexed. Operates in-place. """ self._check_fit() - if mapping is not None and order is not None: + if sum(x is not None for x in (mapping, order, template)) > 1: raise ValueError( - "Only one of 'mapping' or 'order' must be provided." + "Only one of 'mapping', 'order' or 'template' " + "must be provided." ) + # Mapping if mapping is not None: _check_type(mapping, (dict,), item_name="mapping") valids = tuple(range(self._n_clusters)) @@ -419,6 +441,7 @@ def reorder_clusters( # sanity-check assert len(set(order)) == self._n_clusters + # Order elif order is not None: _check_type(order, (list, tuple, np.ndarray), item_name="order") if isinstance(order, np.ndarray) and len(order.shape) != 1: @@ -436,9 +459,13 @@ def reorder_clusters( ) order = list(order) + # Cluster + elif template is not None: + order = optimize_order(self, template) + else: logger.warning( - "Either 'mapping' or 'order' should not be 'None' " + "Either 'mapping', 'order' or 'template' should not be 'None' " "for method 'reorder_clusters' to operate." ) return diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index d98ab39b..99c12d2b 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -43,7 +43,7 @@ def __init__( self._n_clusters = _BaseCluster._check_n_clusters(n_clusters) self._cluster_names = [str(k) for k in range(self.n_clusters)] - # TODO : ignor_polarity=True for now. + # TODO : ignore_polarity=True for now. # After _BaseCluster and Metric support ignore_polarity # make the parameter an argument # https://github.com/vferat/pycrostates/pull/93#issue-1431122168 diff --git a/pycrostates/cluster/kmeans.py b/pycrostates/cluster/kmeans.py index ed1ab059..26ab7f7b 100644 --- a/pycrostates/cluster/kmeans.py +++ b/pycrostates/cluster/kmeans.py @@ -236,6 +236,7 @@ def fit( self._cluster_centers_ = best_maps self._labels_ = best_segmentation self._fitted = True + self._ignore_polarity = True @copy_doc(_BaseCluster.save) def save(self, fname: Union[str, Path]): diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index 2be65595..35a2a0e3 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -528,9 +528,15 @@ def test_reorder(caplog): ) aahCluster_.reorder_clusters() - assert "Either 'mapping' or 'order' should not be 'None' " in caplog.text + assert ( + "Either 'mapping', 'order' or 'template' should not be 'None' " + in caplog.text + ) - with pytest.raises(ValueError, match="Only one of 'mapping' or 'order'"): + with pytest.raises( + ValueError, + match="Only one of 'mapping', 'order' or 'template' must be provided.", + ): aahCluster_.reorder_clusters(mapping={0: 1}, order=[1, 0, 2, 3]) # Test unfitted diff --git a/pycrostates/cluster/tests/test_kmeans.py b/pycrostates/cluster/tests/test_kmeans.py index 0e25dc33..7cc95be3 100644 --- a/pycrostates/cluster/tests/test_kmeans.py +++ b/pycrostates/cluster/tests/test_kmeans.py @@ -409,9 +409,15 @@ def test_reorder(caplog): ModK_.reorder_clusters(order=np.array([[0, 1, 2, 3], [0, 1, 2, 3]])) ModK_.reorder_clusters() - assert "Either 'mapping' or 'order' should not be 'None' " in caplog.text + assert ( + "Either 'mapping', 'order' or 'template' should not be 'None'" + in caplog.text + ) - with pytest.raises(ValueError, match="Only one of 'mapping' or 'order'"): + with pytest.raises( + ValueError, + match="Only one of 'mapping', 'order' or 'template' must be provided.", + ): ModK_.reorder_clusters(mapping={0: 1}, order=[1, 0, 2, 3]) # Test unfitted @@ -423,6 +429,13 @@ def test_reorder(caplog): with pytest.raises(RuntimeError, match="must be fitted before"): ModK_.reorder_clusters(order=[1, 0, 2, 3]) + # Test template + ModK_ = ModK.copy() + ModK__ = ModK_.copy() + ModK_.reorder_clusters(order=np.array([1, 0, 2, 3])) + ModK_.reorder_clusters(template=ModK__) + assert np.allclose(ModK_.cluster_centers_, ModK_.cluster_centers_) + def test_properties(caplog): """Test properties.""" @@ -746,7 +759,7 @@ def test_refit(): ModK_.fit(raw, picks="mag") mag_ch_names = ModK_.info["ch_names"] mag_cluster_centers = ModK_.cluster_centers_ - assert eeg_ch_names != mag_ch_names + assert not np.array_equal(eeg_ch_names, mag_ch_names) assert eeg_cluster_centers.shape != mag_cluster_centers.shape # invalid @@ -764,7 +777,7 @@ def test_refit(): with pytest.raises(RuntimeError, match="must be unfitted"): ModK_.fit(raw, picks="mag") # works assert eeg_ch_names == ModK_.info["ch_names"] - assert np.allclose(eeg_cluster_centers, ModK_.cluster_centers_) + assert np.isclose(eeg_cluster_centers, ModK_.cluster_centers_).all() def test_predict_default(caplog): diff --git a/pycrostates/cluster/utils/__init__.py b/pycrostates/cluster/utils/__init__.py new file mode 100644 index 00000000..d224258a --- /dev/null +++ b/pycrostates/cluster/utils/__init__.py @@ -0,0 +1,5 @@ +"""This module contains the clustering utils functions of ``pycrostates``.""" + +from .utils import optimize_order # noqa: F401 + +__all__ = ("optimize_order",) diff --git a/pycrostates/cluster/utils/tests/test_utils.py b/pycrostates/cluster/utils/tests/test_utils.py new file mode 100644 index 00000000..c85fa995 --- /dev/null +++ b/pycrostates/cluster/utils/tests/test_utils.py @@ -0,0 +1,76 @@ +"""Test cluster utils.""" + +import numpy as np +from mne.datasets import testing +from mne.io import read_raw_fif + +from pycrostates.cluster import ModKMeans +from pycrostates.cluster.utils.utils import _optimize_order, optimize_order + +directory = testing.data_path() / "MEG" / "sample" +fname = directory / "sample_audvis_trunc_raw.fif" +raw = read_raw_fif(fname, preload=False) +raw.pick("eeg").crop(0, 10) +raw.load_data() +# Fit one for general purposes +n_clusters = 5 +ModK_0 = ModKMeans( + n_clusters=n_clusters, n_init=10, max_iter=100, tol=1e-4, random_state=0 +) +ModK_0.fit(raw, n_jobs=1) + +ModK_1 = ModKMeans( + n_clusters=n_clusters, n_init=10, max_iter=100, tol=1e-4, random_state=1 +) +ModK_1.fit(raw, n_jobs=1) + + +def test__optimize_order(): + template = ModK_1._cluster_centers_ + n_states, n_electrodes = template.shape + # Shuffle template + arr = np.arange(n_states) + np.random.shuffle(arr) + random_template = template[arr] + # invert polarity + polarities = np.array([-1, 1, -1, 1, 1]) + random_pol_template = polarities[:, np.newaxis] * random_template + + # No suffle + current = template + ignore_polarity = True + order = _optimize_order(current, template, ignore_polarity=ignore_polarity) + assert np.all(order == np.arange(n_states)) + + # Shuffle + current = random_template + ignore_polarity = False + order = _optimize_order(current, template, ignore_polarity=ignore_polarity) + assert np.allclose(current[order], template) + + # Shuffle + ignore_polarity + current = random_template + ignore_polarity = True + order = _optimize_order(current, template, ignore_polarity=ignore_polarity) + assert np.allclose(current[order], template) + + # Shuffle + sign + ignore_polarity + current = random_pol_template + ignore_polarity = True + order_ = _optimize_order( + current, template, ignore_polarity=ignore_polarity + ) + assert np.all(order == order_) + + # Shuffle + sign + current = random_pol_template + ignore_polarity = False + order = _optimize_order(current, template, ignore_polarity=ignore_polarity) + corr = np.corrcoef(template, current[order])[n_states:, :n_states] + corr_order = np.corrcoef(template, current[order])[n_states:, :n_states] + assert np.trace(corr) <= np.trace(corr_order) + + +def test_optimize_order(): + order = optimize_order(ModK_0, ModK_1) + assert np.all(np.sort(np.unique(order)) == np.arange(len(order))) diff --git a/pycrostates/cluster/utils/utils.py b/pycrostates/cluster/utils/utils.py new file mode 100644 index 00000000..adc04d9f --- /dev/null +++ b/pycrostates/cluster/utils/utils.py @@ -0,0 +1,71 @@ +import numpy as np +import scipy +from numpy.typing import NDArray + +from ..._typing import Cluster +from ...utils._checks import _check_type +from ...utils._docs import fill_doc + + +def _optimize_order( + centers: NDArray[float], + template_centers: NDArray[float], + ignore_polarity: bool = True, +): + n_states = len(centers) + M = np.corrcoef(template_centers, centers)[:n_states, n_states:] + if ignore_polarity: + M = np.abs(M) + _, order = scipy.optimize.linear_sum_assignment(-M) + return order + + +@fill_doc +def optimize_order(inst: Cluster, template_inst: Cluster): + """Optimize the order of cluster centers between two cluster instances. + + Optimize the order of cluster centers in an instance of a clustering + algorithm to maximize auto-correlation, based on a template instance + as determined by the Hungarian algorithm. + The two cluster instances must have the same number of cluster centers + and the same polarity setting. + + Parameters + ---------- + inst : :ref:`cluster` + Fitted clustering algorithm to reorder. + template_inst : :ref:`cluster` + Fitted clustering algorithm to use as template for reordering. + + Returns + ------- + order : list of int + The new order to apply to inst to maximize auto-correlation + of cluster centers. + """ + from .._base import _BaseCluster + + _check_type(inst, (_BaseCluster,), item_name="inst") + inst._check_fit() + _check_type(template_inst, (_BaseCluster,), item_name="template_inst") + template_inst._check_fit() + + if inst.n_clusters != template_inst.n_clusters: + raise ValueError( + "Instance and the template must have the same " + "number of cluster centers." + ) + if inst._ignore_polarity != template_inst._ignore_polarity: + raise ValueError( + "Cannot find order: Instance was fitted using " + f"ignore_polarity={inst._ignore_polarity} while " + "template was fitted using ignore_polarity=" + f"{inst._ignore_polarity} which could lead to " + "misinterpretations." + ) + inst_centers = inst._cluster_centers_ + template_centers = template_inst._cluster_centers_ + order = _optimize_order( + inst_centers, template_centers, ignore_polarity=inst._ignore_polarity + ) + return order