Skip to content

Commit

Permalink
Auto-label/order microstates maps (#105)
Browse files Browse the repository at this point in the history
* Add optimize_order

* Fix style

* Fix style

* fix tests

* Fix style

* Apply suggestions from code review

Co-authored-by: Mathieu Scheltienne <[email protected]>

* Add template option to base.reorder

* Fix docstring

* Fix docstrings

* Update .gitignore

* Update pycrostates/cluster/utils/utils.py

* Update pycrostates/cluster/_base.py

* Update pycrostates/cluster/_base.py

* Update pycrostates/cluster/_base.py

* add type-hint for Cluster class

* fix typo

---------

Co-authored-by: Mathieu Scheltienne <[email protected]>
Co-authored-by: Mathieu Scheltienne <[email protected]>
  • Loading branch information
3 people authored Apr 13, 2023
1 parent 3dd164e commit 91df686
Show file tree
Hide file tree
Showing 12 changed files with 240 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,4 @@ dmypy.json

# ----------------------------------------------------------------------------
.DS_Store
junit-results.xml
14 changes: 14 additions & 0 deletions docs/source/api/cluster.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,17 @@ Cluster

ModKMeans
AAHCluster

Utils
-----

.. currentmodule:: pycrostates.cluster.utils

.. automodule:: pycrostates.cluster.utils
:no-members:
:no-inherited-members:

.. autosummary::
:toctree: generated/

optimize_order
6 changes: 6 additions & 0 deletions pycrostates/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,11 @@ class CHInfo(ABC):
pass


class Cluster(ABC):
    """Placeholder base class used to type-hint clustering objects."""


RANDomState = Optional[Union[int, RandomState, Generator]]
Picks = Optional[Union[str, NDArray[int]]]
6 changes: 5 additions & 1 deletion pycrostates/cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
:class:`~pycrostates.segmentation.EpochsSegmentation` depending on the dataset
to segment."""

from . import utils # noqa: F401
from .aahc import AAHCluster # noqa: F401
from .kmeans import ModKMeans # noqa: F401

__all__ = ("ModKMeans", "AAHCluster")
__all__ = (
"ModKMeans",
"AAHCluster",
)
43 changes: 35 additions & 8 deletions pycrostates/cluster/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from copy import copy, deepcopy
from itertools import groupby
from pathlib import Path
Expand All @@ -13,7 +13,7 @@
from numpy.typing import NDArray
from scipy.signal import convolve2d

from .._typing import CHData, Picks
from .._typing import CHData, Cluster, Picks
from ..segmentation import EpochsSegmentation, RawSegmentation
from ..utils import _corr_vectors
from ..utils._checks import (
Expand All @@ -27,16 +27,18 @@
from ..utils._logs import logger, verbose
from ..utils.mixin import ChannelsMixin, ContainsMixin, MontageMixin
from ..viz import plot_cluster_centers
from .utils import optimize_order


class _BaseCluster(ABC, ChannelsMixin, ContainsMixin, MontageMixin):
class _BaseCluster(Cluster, ChannelsMixin, ContainsMixin, MontageMixin):
"""Base Class for Microstates Clustering algorithms."""

@abstractmethod
def __init__(self):
self._n_clusters = None
self._cluster_names = None
self._cluster_centers_ = None
self._ignore_polarity = None

# fit variables
self._info = None
Expand Down Expand Up @@ -363,8 +365,20 @@ def reorder_clusters(
NDArray[int],
]
] = None,
template: Optional[Cluster] = None,
) -> None:
"""Reorder the clusters.
"""
Reorder the clusters of the fitted model.
Specify one of the following arguments to change the current order:
* ``mapping``: a dictionary that maps old cluster positions
to new positions,
* ``order``: a 1D iterable containing the new order,
* ``template``: a fitted clustering algorithm used as a reference
to match the order.
Only one argument can be set at a time.
Parameters
----------
Expand All @@ -373,19 +387,27 @@ def reorder_clusters(
key: old position, value: new position.
order : list of int | tuple of int | array of int
1D iterable containing the new order.
Positions are 0-indexed.
template : :ref:`cluster`
Fitted clustering algorithm used as template for
ordering optimization. For more details about the
current implementation, check the
:func:`pycrostates.cluster.utils.optimize_order`
documentation.
Notes
-----
The positions are 0-indexed.
Operates in-place.
"""
self._check_fit()

if mapping is not None and order is not None:
if sum(x is not None for x in (mapping, order, template)) > 1:
raise ValueError(
"Only one of 'mapping' or 'order' must be provided."
"Only one of 'mapping', 'order' or 'template' "
"must be provided."
)

# Mapping
if mapping is not None:
_check_type(mapping, (dict,), item_name="mapping")
valids = tuple(range(self._n_clusters))
Expand Down Expand Up @@ -419,6 +441,7 @@ def reorder_clusters(
# sanity-check
assert len(set(order)) == self._n_clusters

# Order
elif order is not None:
_check_type(order, (list, tuple, np.ndarray), item_name="order")
if isinstance(order, np.ndarray) and len(order.shape) != 1:
Expand All @@ -436,9 +459,13 @@ def reorder_clusters(
)
order = list(order)

# Cluster
elif template is not None:
order = optimize_order(self, template)

else:
logger.warning(
"Either 'mapping' or 'order' should not be 'None' "
"Either 'mapping', 'order' or 'template' should not be 'None' "
"for method 'reorder_clusters' to operate."
)
return
Expand Down
2 changes: 1 addition & 1 deletion pycrostates/cluster/aahc.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
self._n_clusters = _BaseCluster._check_n_clusters(n_clusters)
self._cluster_names = [str(k) for k in range(self.n_clusters)]

# TODO : ignor_polarity=True for now.
# TODO : ignore_polarity=True for now.
# After _BaseCluster and Metric support ignore_polarity
# make the parameter an argument
# https://github.com/vferat/pycrostates/pull/93#issue-1431122168
Expand Down
1 change: 1 addition & 0 deletions pycrostates/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def fit(
self._cluster_centers_ = best_maps
self._labels_ = best_segmentation
self._fitted = True
self._ignore_polarity = True

@copy_doc(_BaseCluster.save)
def save(self, fname: Union[str, Path]):
Expand Down
10 changes: 8 additions & 2 deletions pycrostates/cluster/tests/test_aahc.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,9 +528,15 @@ def test_reorder(caplog):
)

aahCluster_.reorder_clusters()
assert "Either 'mapping' or 'order' should not be 'None' " in caplog.text
assert (
"Either 'mapping', 'order' or 'template' should not be 'None' "
in caplog.text
)

with pytest.raises(ValueError, match="Only one of 'mapping' or 'order'"):
with pytest.raises(
ValueError,
match="Only one of 'mapping', 'order' or 'template' must be provided.",
):
aahCluster_.reorder_clusters(mapping={0: 1}, order=[1, 0, 2, 3])

# Test unfitted
Expand Down
21 changes: 17 additions & 4 deletions pycrostates/cluster/tests/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,15 @@ def test_reorder(caplog):
ModK_.reorder_clusters(order=np.array([[0, 1, 2, 3], [0, 1, 2, 3]]))

ModK_.reorder_clusters()
assert "Either 'mapping' or 'order' should not be 'None' " in caplog.text
assert (
"Either 'mapping', 'order' or 'template' should not be 'None'"
in caplog.text
)

with pytest.raises(ValueError, match="Only one of 'mapping' or 'order'"):
with pytest.raises(
ValueError,
match="Only one of 'mapping', 'order' or 'template' must be provided.",
):
ModK_.reorder_clusters(mapping={0: 1}, order=[1, 0, 2, 3])

# Test unfitted
Expand All @@ -423,6 +429,13 @@ def test_reorder(caplog):
with pytest.raises(RuntimeError, match="must be fitted before"):
ModK_.reorder_clusters(order=[1, 0, 2, 3])

# Test template
ModK_ = ModK.copy()
ModK__ = ModK_.copy()
ModK_.reorder_clusters(order=np.array([1, 0, 2, 3]))
ModK_.reorder_clusters(template=ModK__)
assert np.allclose(ModK_.cluster_centers_, ModK_.cluster_centers_)


def test_properties(caplog):
"""Test properties."""
Expand Down Expand Up @@ -746,7 +759,7 @@ def test_refit():
ModK_.fit(raw, picks="mag")
mag_ch_names = ModK_.info["ch_names"]
mag_cluster_centers = ModK_.cluster_centers_
assert eeg_ch_names != mag_ch_names
assert not np.array_equal(eeg_ch_names, mag_ch_names)
assert eeg_cluster_centers.shape != mag_cluster_centers.shape

# invalid
Expand All @@ -764,7 +777,7 @@ def test_refit():
with pytest.raises(RuntimeError, match="must be unfitted"):
ModK_.fit(raw, picks="mag") # works
assert eeg_ch_names == ModK_.info["ch_names"]
assert np.allclose(eeg_cluster_centers, ModK_.cluster_centers_)
assert np.isclose(eeg_cluster_centers, ModK_.cluster_centers_).all()


def test_predict_default(caplog):
Expand Down
5 changes: 5 additions & 0 deletions pycrostates/cluster/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""This module contains the clustering utils functions of ``pycrostates``."""

from .utils import optimize_order # noqa: F401

__all__ = ("optimize_order",)
76 changes: 76 additions & 0 deletions pycrostates/cluster/utils/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Test cluster utils."""

import numpy as np
from mne.datasets import testing
from mne.io import read_raw_fif

from pycrostates.cluster import ModKMeans
from pycrostates.cluster.utils.utils import _optimize_order, optimize_order

# Module-level fixtures shared by the tests below: one recording and two
# fitted ModKMeans models built from it.
# Path of the sample recording inside the MNE testing dataset.
directory = testing.data_path() / "MEG" / "sample"
fname = directory / "sample_audvis_trunc_raw.fif"
# Open without preloading so that picking/cropping happens before the data is
# actually read into memory by load_data().
raw = read_raw_fif(fname, preload=False)
raw.pick("eeg").crop(0, 10)
raw.load_data()
# Fit one for general purposes
n_clusters = 5
ModK_0 = ModKMeans(
n_clusters=n_clusters, n_init=10, max_iter=100, tol=1e-4, random_state=0
)
ModK_0.fit(raw, n_jobs=1)

# Second model fitted on the same data with identical settings except for the
# random_state; the tests use it as the template to match against.
ModK_1 = ModKMeans(
n_clusters=n_clusters, n_init=10, max_iter=100, tol=1e-4, random_state=1
)
ModK_1.fit(raw, n_jobs=1)


def test__optimize_order():
    """Test the private helper that matches cluster centers to a template."""
    template = ModK_1._cluster_centers_
    n_states = template.shape[0]
    # Shuffle template. A seeded generator keeps the test reproducible instead
    # of depending on the state of the global NumPy RNG.
    rng = np.random.default_rng(0)
    arr = rng.permutation(n_states)
    random_template = template[arr]
    # invert polarity
    polarities = np.array([-1, 1, -1, 1, 1])
    random_pol_template = polarities[:, np.newaxis] * random_template

    # No shuffle
    current = template
    ignore_polarity = True
    order = _optimize_order(current, template, ignore_polarity=ignore_polarity)
    assert np.all(order == np.arange(n_states))

    # Shuffle
    current = random_template
    ignore_polarity = False
    order = _optimize_order(current, template, ignore_polarity=ignore_polarity)
    assert np.allclose(current[order], template)

    # Shuffle + ignore_polarity
    current = random_template
    ignore_polarity = True
    order = _optimize_order(current, template, ignore_polarity=ignore_polarity)
    assert np.allclose(current[order], template)

    # Shuffle + sign + ignore_polarity
    # Flipping signs must not change the order found when polarity is ignored.
    current = random_pol_template
    ignore_polarity = True
    order_ = _optimize_order(
        current, template, ignore_polarity=ignore_polarity
    )
    assert np.all(order == order_)

    # Shuffle + sign
    # With polarity taken into account the exact template cannot be recovered,
    # but the returned order must still be a valid permutation that does at
    # least as well as leaving the centers in their current order.
    current = random_pol_template
    ignore_polarity = False
    order = _optimize_order(current, template, ignore_polarity=ignore_polarity)
    assert np.array_equal(np.sort(order), np.arange(n_states))
    # The trace of the cross-correlation block sums the correlation of each
    # center with the template center at the same position.
    corr = np.corrcoef(template, current)[n_states:, :n_states]
    corr_order = np.corrcoef(template, current[order])[n_states:, :n_states]
    # small tolerance for floating-point round-off between the two calls
    assert np.trace(corr) <= np.trace(corr_order) + 1e-12


def test_optimize_order():
    """Check that the public function returns a permutation of the indices."""
    new_order = optimize_order(ModK_0, ModK_1)
    expected = np.arange(len(new_order))
    found = np.sort(np.unique(new_order))
    assert np.all(found == expected)
71 changes: 71 additions & 0 deletions pycrostates/cluster/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import scipy
from numpy.typing import NDArray

from ..._typing import Cluster
from ...utils._checks import _check_type
from ...utils._docs import fill_doc


def _optimize_order(
centers: NDArray[float],
template_centers: NDArray[float],
ignore_polarity: bool = True,
):
n_states = len(centers)
M = np.corrcoef(template_centers, centers)[:n_states, n_states:]
if ignore_polarity:
M = np.abs(M)
_, order = scipy.optimize.linear_sum_assignment(-M)
return order


@fill_doc
def optimize_order(inst: Cluster, template_inst: Cluster):
    """Optimize the order of cluster centers between two cluster instances.

    Find the order of the cluster centers of ``inst`` that maximizes the
    correlation with the cluster centers of ``template_inst``. The pairing
    between both sets of centers is obtained by solving a linear sum
    assignment problem on their correlation matrix.
    The two cluster instances must have the same number of cluster centers
    and the same polarity setting.

    Parameters
    ----------
    inst : :ref:`cluster`
        Fitted clustering algorithm to reorder.
    template_inst : :ref:`cluster`
        Fitted clustering algorithm to use as template for reordering.

    Returns
    -------
    order : array of int
        The new order to apply to ``inst`` to maximize the correlation of
        its cluster centers with those of ``template_inst``.

    Raises
    ------
    ValueError
        If both instances do not have the same number of cluster centers or
        were not fitted with the same ``ignore_polarity`` setting.
    """
    # Imported here rather than at module level because ``_base`` itself
    # imports this module.
    from .._base import _BaseCluster

    _check_type(inst, (_BaseCluster,), item_name="inst")
    inst._check_fit()
    _check_type(template_inst, (_BaseCluster,), item_name="template_inst")
    template_inst._check_fit()

    if inst.n_clusters != template_inst.n_clusters:
        raise ValueError(
            "Instance and the template must have the same "
            "number of cluster centers."
        )
    if inst._ignore_polarity != template_inst._ignore_polarity:
        # Report both settings; the second value is the template's.
        raise ValueError(
            "Cannot find order: Instance was fitted using "
            f"ignore_polarity={inst._ignore_polarity} while "
            "template was fitted using ignore_polarity="
            f"{template_inst._ignore_polarity} which could lead to "
            "misinterpretations."
        )
    inst_centers = inst._cluster_centers_
    template_centers = template_inst._cluster_centers_
    order = _optimize_order(
        inst_centers, template_centers, ignore_polarity=inst._ignore_polarity
    )
    return order

0 comments on commit 91df686

Please sign in to comment.