new cluster test API #12663

Draft
wants to merge 99 commits into base: main
Changes from 24 commits (99 commits total)
62daaf0
added cluster test api, first commit
CarinaFo Jun 14, 2024
d59978f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2024
2843905
tested dataframe function and results, cleaned up
CarinaFo Jun 14, 2024
1985da3
Merge branch 'new_cluster_stats_api_GSOC24' of https://github.com/Car…
CarinaFo Jun 14, 2024
fa5b215
added ToDos
CarinaFo Jun 14, 2024
1a1511d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2024
0ea220c
Merge branch 'new_cluster_stats_api_GSOC24' of https://github.com/Car…
CarinaFo Jun 14, 2024
a12cf95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2024
3c5d4f1
Merge branch 'mne-tools:main' into new_cluster_stats_api_GSOC24
CarinaFo Jun 19, 2024
45ce63a
added formula support and implemented suggestions
CarinaFo Jun 19, 2024
2b7bae8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2024
38834ba
fixed linting errors
CarinaFo Jun 22, 2024
c00859f
ENH: Add dataset [skip azp] [skip actions]
larsoner Jun 25, 2024
9c8ec90
FIX: One more [skip azp] [skip actions]
larsoner Jun 25, 2024
47363b5
FIX: Title [skip azp] [skip actions]
larsoner Jun 25, 2024
1f6221d
first draft of formulaic paired t-test
CarinaFo Jun 30, 2024
37616e5
first draft without cluster plotting class implemented
CarinaFo Jul 6, 2024
6aaef9a
cleaned up plotting function
CarinaFo Jul 6, 2024
0f99c70
implemented cluser results class
CarinaFo Jul 6, 2024
4083691
added contribution
CarinaFo Jul 6, 2024
42d70f9
Merge branch 'mne-tools:main' into new_cluster_stats_api_GSOC24
CarinaFo Jul 18, 2024
8a57215
Merge branch 'new_cluster_stats_api_GSOC24' of https://github.com/Car…
CarinaFo Jul 18, 2024
7e9b2e5
fixed codespell
CarinaFo Jul 18, 2024
12f27dd
Merge branch 'main' into new_cluster_stats_api_GSOC24
CarinaFo Jul 19, 2024
8f510a9
first review
CarinaFo Jul 22, 2024
d6c0c4c
quick clean up
CarinaFo Jul 22, 2024
f17f38f
test compare_old_vs_new_cluster_API
CarinaFo Jul 22, 2024
9d592de
simplify tests
drammock Jul 25, 2024
d64ef84
refactor cluster_test
drammock Jul 25, 2024
dc8a799
make tutorial match modified API
drammock Jul 25, 2024
f12cf6e
remove unused test helper func
drammock Jul 25, 2024
5b97971
vulture allowlist update
drammock Jul 25, 2024
5f5b0fc
included BaseTFR in validate_cluster_df
CarinaFo Jul 28, 2024
ccccb5b
comments on cluster_test function
CarinaFo Jul 28, 2024
59b1a3a
updated clusterResult class and plot function
CarinaFo Jul 28, 2024
98d0879
updated function call for plotting
CarinaFo Jul 28, 2024
ec03242
changed color
CarinaFo Jul 28, 2024
a513518
Merge branch 'main' into new_cluster_stats_api_GSOC24
CarinaFo Jul 31, 2024
1a5dd9d
Merge branch 'new_cluster_stats_api_GSOC24' of https://github.com/Car…
CarinaFo Jul 31, 2024
5941f61
docstring/docdict cleanups and fixes
drammock Aug 1, 2024
368fa44
implemented Dan's comments
CarinaFo Aug 5, 2024
3aa32b6
implemented Dan's comments
CarinaFo Aug 5, 2024
a76afd3
test for handling different MNE objects - test is failing
CarinaFo Aug 5, 2024
b5fce8b
adjusted test to account for multiple subjects
CarinaFo Aug 6, 2024
3ce510c
refactor df validation to return bools
drammock Aug 10, 2024
feb1911
unrelated typing fix
drammock Aug 10, 2024
6f97811
rework test
drammock Aug 10, 2024
b09d20a
minor cleanup
drammock Aug 12, 2024
977e153
fix imports
drammock Aug 12, 2024
a288d85
use MRO in test too
drammock Aug 12, 2024
81ce0d0
Merge pull request #5 from drammock/new_cluster_stats_api_GSOC24
CarinaFo Aug 21, 2024
05586c8
added cluster test api, first commit
CarinaFo Jun 14, 2024
e8770fd
tested dataframe function and results, cleaned up
CarinaFo Jun 14, 2024
a081d7d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2024
d6d70c8
added ToDos
CarinaFo Jun 14, 2024
8345261
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 14, 2024
0373195
added formula support and implemented suggestions
CarinaFo Jun 19, 2024
8bc44f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2024
654a350
fixed linting errors
CarinaFo Jun 22, 2024
d1ed8a1
ENH: Add dataset [skip azp] [skip actions]
larsoner Jun 25, 2024
c634a44
FIX: One more [skip azp] [skip actions]
larsoner Jun 25, 2024
0c2eb4f
FIX: Title [skip azp] [skip actions]
larsoner Jun 25, 2024
f46a79c
first draft of formulaic paired t-test
CarinaFo Jun 30, 2024
5d1cbae
first draft without cluster plotting class implemented
CarinaFo Jul 6, 2024
268d0cf
cleaned up plotting function
CarinaFo Jul 6, 2024
2f722bd
implemented cluser results class
CarinaFo Jul 6, 2024
fb75cfd
fixed codespell
CarinaFo Jul 18, 2024
a87ffed
first review
CarinaFo Jul 22, 2024
1f857ad
quick clean up
CarinaFo Jul 22, 2024
450738b
test compare_old_vs_new_cluster_API
CarinaFo Jul 22, 2024
d41efbe
simplify tests
drammock Jul 25, 2024
9523fae
refactor cluster_test
drammock Jul 25, 2024
9661492
make tutorial match modified API
drammock Jul 25, 2024
cac0559
remove unused test helper func
drammock Jul 25, 2024
47ac838
vulture allowlist update
drammock Jul 25, 2024
033c158
included BaseTFR in validate_cluster_df
CarinaFo Jul 28, 2024
2c2f341
comments on cluster_test function
CarinaFo Jul 28, 2024
e9b5fa2
updated clusterResult class and plot function
CarinaFo Jul 28, 2024
2fd17d3
updated function call for plotting
CarinaFo Jul 28, 2024
150c530
changed color
CarinaFo Jul 28, 2024
3cc9e2c
docstring/docdict cleanups and fixes
drammock Aug 1, 2024
2c27a69
implemented Dan's comments
CarinaFo Aug 5, 2024
2664ee2
implemented Dan's comments
CarinaFo Aug 5, 2024
4927544
test for handling different MNE objects - test is failing
CarinaFo Aug 5, 2024
006acdf
adjusted test to account for multiple subjects
CarinaFo Aug 6, 2024
f0f4cba
refactor df validation to return bools
drammock Aug 10, 2024
346e3ce
unrelated typing fix
drammock Aug 10, 2024
a49d2cd
rework test
drammock Aug 10, 2024
a01182b
minor cleanup
drammock Aug 12, 2024
0984b61
fix imports
drammock Aug 12, 2024
6322499
use MRO in test too
drammock Aug 12, 2024
a04b8a3
fix vulture allowlist
drammock Aug 22, 2024
f1d39bf
fix nesting and type hints
drammock Aug 22, 2024
987ea43
strict=False
drammock Aug 22, 2024
78829b4
nest import in test file too
drammock Aug 22, 2024
eb98849
Merge branch 'new_cluster_stats_api_GSOC24' of https://github.com/Car…
CarinaFo Aug 23, 2024
ac943e3
Merge branch 'main' into new_cluster_stats_api_GSOC24
CarinaFo Sep 17, 2024
372bcca
clean up pyproject mess
CarinaFo Oct 2, 2024
4da8463
add n_permutations, plotting, added min_cluster_p_value
CarinaFo Oct 2, 2024
1 change: 1 addition & 0 deletions environment.yml
@@ -65,3 +65,4 @@ dependencies:
- lazy_loader
- defusedxml
- python-neo
- formulaic
4 changes: 2 additions & 2 deletions mne/datasets/config.py
@@ -90,7 +90,7 @@
# here: ↓↓↓↓↓↓↓↓
RELEASES = dict(
testing="0.152",
-misc="0.27",
+misc="0.30",
phantom_kit="0.2",
ucl_opm_auditory="0.2",
)
@@ -131,7 +131,7 @@
)
MNE_DATASETS["misc"] = dict(
archive_name=f"{MISC_VERSIONED}.tar.gz", # 'mne-misc-data',
-hash="md5:e343d3a00cb49f8a2f719d14f4758afe",
+hash="md5:201d35531d3c03701cf50e38bb73481f",
url=(
"https://codeload.github.com/mne-tools/mne-misc-data/tar.gz/"
f'{RELEASES["misc"]}'
315 changes: 315 additions & 0 deletions mne/stats/cluster_level.py
@@ -6,16 +6,22 @@
# Eric Larson <[email protected]>
# Denis Engemann <[email protected]>
# Fernando Perez (bin_perm_rep function)
# Carina Forster <[email protected]>
#
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy import ndimage, sparse
from scipy.sparse.csgraph import connected_components
from scipy.stats import f as fstat
from scipy.stats import t as tstat

from .. import EvokedArray
from ..channels import find_ch_adjacency
from ..fixes import has_numba, jit
from ..parallel import parallel_func
from ..source_estimate import MixedSourceEstimate, SourceEstimate, VolSourceEstimate
@@ -24,13 +30,15 @@
ProgressBar,
_check_option,
_pl,
_soft_import,
_validate_type,
check_random_state,
logger,
split_list,
verbose,
warn,
)
from ..viz import plot_compare_evokeds
from .parametric import f_oneway, ttest_1samp_no_p


@@ -1729,3 +1737,310 @@ def summarize_clusters_stc(
data_summary[:, 0] = np.sum(data_summary, axis=1)

return klass(data_summary, vertices, tmin, tstep, subject)


def cluster_test(
df: pd.DataFrame,
formula: str | None = None,  # Wilkinson notation formula for design matrix
n_permutations: int = 10000,
seed: None | int | np.random.RandomState = None,
tail: int = 0,  # 0 for two-tailed, 1 for greater, -1 for less
n_jobs: int = 1,  # how many cores to use
adjacency: tuple | None = None,
max_step: int = 1,  # maximum distance between samples (time points)
exclude: list | None = None,  # time points/channels to exclude (None = keep all)
step_down_p: float = 0,  # threshold for the step-down-in-jumps test
t_power: float = 1,  # power to raise the stats by before summing
out_type: str = "indices",
check_disjoint: bool = False,
buffer_size: int | None = None,  # block size for chunking the data
):
"""
Run a cluster permutation test based on formulaic input.

Currently only supports a paired t-test on evokeds or epochs.

Parameters
----------
df : pd.DataFrame
Dataframe with evoked/epoched data, conditions and subject IDs.
formula : str, optional
Wilkinson notation formula for design matrix. Default is None.
n_permutations : int, optional
Number of permutations. Default is 10000.
seed : None | int | np.random.RandomState, optional
Seed for the random number generator. Default is None.
tail : int, optional
0 for two-tailed, 1 for greater, -1 for less. Default is 0.
n_jobs : int, optional
How many cores to use. Default is 1.
adjacency : tuple | None, optional
Adjacency matrix; if None, one is computed from the data. Default is None.
max_step : int, optional
Maximum distance between samples (time points). Default is 1.
exclude : list | None, optional
Time points or channels to exclude. Default is None (keep all).
step_down_p : float, optional
Threshold p-value for the step-down-in-jumps test. Default is 0 (off).
t_power : float, optional
Power to raise the statistical values by before summing (for weighting each location). Default is 1.
out_type : str, optional
Output type, "indices" or "mask". Default is "indices".
check_disjoint : bool, optional
Check if clusters are disjoint. Default is False.
buffer_size : int, optional
Block size for chunking the data. Default is None.

Returns
-------
ClusterResult
Object containing the results of the cluster permutation test.
"""
# for now this assumes a dataframe with a column for evoked data or epochs
# add a data column to the dataframe (numpy array)
df["data"] = [evoked.data for evoked in df.evoked]

# extract number of channels and timepoints
# (eventually should also allow for frequency)
n_channels, n_timepoints = df["data"][0].shape

# convert wide format to long format for formulaic
df_long = unpack_time_and_channels(df)

# pivot the DataFrame
pivot_df = df_long.pivot_table(
index=["subject_index", "channel", "timepoint"],
columns="condition",
values="value",
).reset_index()

# if not 2 unique conditions raise error
if len(pd.unique(df.condition)) != 2:
raise ValueError("Condition list needs to contain 2 unique values")

# Get unique elements and the indices of their first occurrences
unique_elements, indices = np.unique(df.condition, return_index=True)

# Sort unique elements by the indices of their first occurrences
conditions = unique_elements[np.argsort(indices)]

# log the contrast used for the paired t-test
logger.info(f"Contrast used for paired t-test: {conditions[0]} - {conditions[1]}")

# Compute the difference (assuming there are only 2 conditions)
pivot_df["evoked"] = pivot_df[conditions[0]] - pivot_df[conditions[1]]

# Optional: Clean up the DataFrame
pivot_df = pivot_df[["subject_index", "channel", "timepoint", "evoked"]]

# check if formula is present
if formula is not None:
formulaic = _soft_import(
"formulaic", purpose="set up Design Matrix"
) # soft import (not a dependency for MNE)

# for the paired t-test y is the difference between conditions
# X is the design matrix with a column with 1s and 0s for each participant
# Create the design matrix using formulaic
y, X = formulaic.model_matrix(formula, pivot_df)
else:
raise ValueError(
"Formula is required and needs to be a string in Wilkinson notation."
)

# now prep design matrix for input into MNE cluster function
# cluster function expects channels as the last dimension
y_for_cluster = y.values.reshape(-1, n_channels, n_timepoints).transpose(0, 2, 1)

if adjacency is None:
adjacency, _ = find_ch_adjacency(df["evoked"][0].info, ch_type="eeg")

# define stat function and threshold
stat_fun, threshold = _check_fun(
X=y_for_cluster, stat_fun=None, threshold=None, tail=0, kind="within"
)

# Run the cluster-based permutation test
T_obs, clusters, cluster_p_values, H0 = _permutation_cluster_test(
[y_for_cluster],
n_permutations=n_permutations,
threshold=threshold,
stat_fun=stat_fun,
tail=tail,
n_jobs=n_jobs,
adjacency=adjacency,
max_step=max_step, # maximum distance between samples (time points)
exclude=exclude, # exclude no time points or channels
step_down_p=step_down_p, # step down in jumps test
t_power=t_power, # weigh each location by its stats score
out_type=out_type,
check_disjoint=check_disjoint,
buffer_size=buffer_size, # block size for chunking the data
seed=seed,
)

logger.info(f"smallest cluster p-value: {min(cluster_p_values)}")

return ClusterResult(T_obs, clusters, cluster_p_values, H0)
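Not part of the diff: a minimal sketch of the input frame `cluster_test` expects, one row per subject/condition pair, with random numpy arrays standing in for the per-observation data. The subject count, condition labels, and shapes here are invented; a real call would hold Evoked objects in an `evoked` column and pass a Wilkinson-notation formula.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n_subjects, n_channels, n_timepoints = 5, 3, 10

# one row per subject/condition pair; "data" stands in for evoked.data
rows = [
    dict(
        subject_index=subj,
        condition=cond,
        data=rng.standard_normal((n_channels, n_timepoints)),
    )
    for subj in range(n_subjects)
    for cond in ("faces", "cars")  # hypothetical condition labels
]
df = pd.DataFrame(rows)
```

A real invocation would then look roughly like `cluster_test(df, formula=...)`, with the exact formula depending on the design.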


def unpack_time_and_channels(df: pd.DataFrame) -> pd.DataFrame:
"""
Extract timepoints and channels and convert to long.

Parameters
----------
df : pd.DataFrame
DataFrame in wide format.

Returns
-------
df_long : pd.DataFrame
DataFrame in long format.
"""
# Extracting all necessary data using list comprehensions for better performance
long_format_data = [
{
"condition": row["condition"],
"subject_index": row["subject_index"],
"channel": channel,
"timepoint": timepoint,
"value": row["data"][channel, timepoint],
}
for idx, row in df.iterrows()
for channel in range(row["data"].shape[0])
for timepoint in range(row["data"].shape[1])
]

# Creating the long format DataFrame
df_long = pd.DataFrame(long_format_data)
Reviewer comment (drammock): the approach of first making a list of dictionaries, and then creating a dataframe from that, will (I think?) involve an extra copy, which means increased memory usage. So if possible, try to find a way to use pd.DataFrame.explode or a similar approach, to go from wide DF to long DF in one step.

return df_long
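Not part of the diff: one way to follow the review suggestion above, sketched under the assumption that `data` holds equally-shaped 2D arrays, is to pre-build flattened index arrays and let `pd.DataFrame.explode` do the wide-to-long step. The function name and demo data are hypothetical.

```python
import numpy as np
import pandas as pd

def unpack_time_and_channels_fast(df: pd.DataFrame) -> pd.DataFrame:
    """Wide-to-long in one step via DataFrame.explode (per the review suggestion)."""
    n_ch, n_tp = df["data"].iloc[0].shape
    out = df[["condition", "subject_index"]].copy()
    out["value"] = df["data"].map(np.ravel)  # flatten each (n_ch, n_tp) array
    out["channel"] = [np.repeat(np.arange(n_ch), n_tp)] * len(df)
    out["timepoint"] = [np.tile(np.arange(n_tp), n_ch)] * len(df)
    # multi-column explode keeps value/channel/timepoint aligned row-wise
    return out.explode(["value", "channel", "timepoint"], ignore_index=True)

# tiny demo: one subject, one condition, a 2-channel x 3-timepoint array
demo = pd.DataFrame(
    {"condition": ["a"], "subject_index": [0], "data": [np.arange(6).reshape(2, 3)]}
)
long_df = unpack_time_and_channels_fast(demo)
```

Multi-column `explode` needs pandas >= 1.3; whether this actually saves memory over the list-of-dicts version would need profiling.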


class ClusterResult:
"""
Object containing the results of the cluster permutation test.

Parameters
----------
T_obs : np.ndarray
The observed test statistic.
clusters : list
List of clusters.
cluster_p_values : np.ndarray
P-values for each cluster.
H0 : np.ndarray
Max cluster level stats observed under permutation.
"""

def __init__(self, T_obs, clusters, cluster_p_values, H0):
self.T_obs = T_obs
self.clusters = clusters
self.cluster_p_values = cluster_p_values
self.H0 = H0

def plot_cluster(self, cond_dict: dict):
"""
Plot the cluster with the lowest p-value.

2D cluster plotted with topoplot on the left and evoked signals on the right.
Timepoints that are part of the cluster are
highlighted in green on the evoked signals.

Parameters
----------
cond_dict : dict
Dictionary with condition labels as keys and lists of evoked objects as values.

Returns
-------
None

"""
# extract condition labels from the dictionary
cond_keys = list(cond_dict.keys())
# extract the evokeds from the dictionary
cond_values = list(cond_dict.values())

# configure variables for visualization
colors = {cond_keys[0]: "crimson", cond_keys[1]: "steelblue"}

lowest_p_cluster = np.argmin(self.cluster_p_values)

# plot the cluster with the lowest p-value
time_inds, space_inds = np.squeeze(self.clusters[lowest_p_cluster])
ch_inds = np.unique(space_inds)
time_inds = np.unique(time_inds)

# get topography for t stat (mean over the cluster's time points)
t_map = self.T_obs[time_inds, ...].mean(axis=0)

# get signals at the sensors contributing to the cluster
sig_times = cond_values[0][0].times[time_inds]

# create spatial mask
mask = np.zeros((t_map.shape[0], 1), dtype=bool)
mask[ch_inds, :] = True

# initialize figure
fig, ax_topo = plt.subplots(1, 1, figsize=(10, 3), layout="constrained")

# plot average test statistic and mark significant sensors
t_evoked = EvokedArray(t_map[:, np.newaxis], cond_values[0][0].info, tmin=0)
t_evoked.plot_topomap(
times=0,
mask=mask,
axes=ax_topo,
cmap="RdBu_r",
show=False,
colorbar=False,
mask_params=dict(markersize=10),
scalings=1.00,
)
image = ax_topo.images[0]

# remove the title that would otherwise say "0.000 s"
ax_topo.set_title("")

# soft import?
# make_axes_locatable = _soft_import(
# "mpl_toolkits.axes_grid1.make_axes_locatable",
# purpose="plot cluster results"
# ) # soft import (not a dependency for MNE)

# create additional axes (for ERF and colorbar)
divider = make_axes_locatable(ax_topo)

# add axes for colorbar
ax_colorbar = divider.append_axes("right", size="5%", pad=0.1)
cbar = plt.colorbar(image, cax=ax_colorbar)
cbar.set_label("t-value")
ax_topo.set_xlabel(
"average from {:0.3f} to {:0.3f} s".format(*sig_times[[0, -1]])
)

# add new axis for time courses and plot time courses
ax_signals = divider.append_axes("right", size="300%", pad=1.3)
title = f"Signal averaged over {len(ch_inds)} sensor(s)"
plot_compare_evokeds(
cond_dict,
title=title,
picks=ch_inds,
axes=ax_signals,
colors=colors,
show=False,
split_legend=True,
truncate_yaxis="auto",
truncate_xaxis=False,
)
plt.legend(frameon=False, loc="upper left")

# plot temporal cluster extent
ymin, ymax = ax_signals.get_ylim()
ax_signals.fill_betweenx(
(ymin, ymax), sig_times[0], sig_times[-1], color="grey", alpha=0.3
)

plt.show()
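Not part of the diff: `plot_cluster` starts by reducing the most significant cluster to unique channel and time indices. A toy sketch of that selection step, with made-up permutation-test outputs (in the real class these come from `_permutation_cluster_test` with `out_type="indices"`):

```python
import numpy as np

# hypothetical cluster-test outputs: p-values and (time_inds, space_inds) pairs
cluster_p_values = np.array([0.04, 0.20, 0.50])
clusters = [
    (np.array([3, 3, 4, 4]), np.array([0, 1, 0, 1])),
    (np.array([7]), np.array([2])),
    (np.array([9, 9]), np.array([1, 2])),
]

lowest_p_cluster = np.argmin(cluster_p_values)  # cluster chosen for plotting
time_inds, space_inds = clusters[lowest_p_cluster]
ch_inds = np.unique(space_inds)   # sensors contributing to the cluster
time_inds = np.unique(time_inds)  # time points spanned by the cluster
```

`ch_inds` then drives the topomap mask and `plot_compare_evokeds` picks, and `time_inds` drives the shaded temporal extent.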
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -107,6 +107,7 @@ full-no-qt = [
"snirf",
"defusedxml",
"neo",
"formulaic",
]
full = ["mne[full-no-qt]", "PyQt6!=6.6.0", "PyQt6-Qt6!=6.6.0,!=6.7.0"]
full-pyqt6 = ["mne[full]"]
@@ -145,6 +146,7 @@ test_extra = [
"snirf",
"neo",
"mne-bids",
"formulaic",
]

# Dependencies for building the documentation
@@ -157,6 +159,7 @@ doc = [
"sphinxcontrib-towncrier",
"memory_profiler",
"neo",
"formulaic",
"seaborn!=0.11.2",
"sphinx_copybutton",
"sphinx-design",