Skip to content

Commit

Permalink
Fix EpochsTFRArray default drop log initialisation (#13028)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns authored Dec 21, 2024
1 parent ee2d0ca commit d814954
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/13028.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix epoch indexing in :class:`mne.time_frequency.EpochsTFRArray` when initialising the class with the default ``drop_log`` parameter, by `Thomas Binns`_.
28 changes: 16 additions & 12 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,25 +1218,29 @@ def test_averaging_freqsandtimes_epochsTFR():
avgpower = power.average(method=lambda x: np.mean(x, axis=2), **kwargs)


@pytest.mark.parametrize("n_drop", (0, 2))
def test_epochstfr_getitem(epochs_full, n_drop):
@pytest.mark.parametrize("n_drop, as_tfr_array", ((0, False), (0, True), (2, False)))
def test_epochstfr_getitem(epochs_full, n_drop, as_tfr_array):
"""Test EpochsTFR.__getitem__()."""
pd = pytest.importorskip("pandas")
from pandas.testing import assert_frame_equal

epochs_full.metadata = pd.DataFrame(dict(foo=list("aaaabbb"), bar=np.arange(7)))
epochs_full.drop(np.arange(n_drop))
tfr = epochs_full.compute_tfr(method="morlet", freqs=freqs_linspace)
# check that various attributes are preserved
assert_frame_equal(tfr.metadata, epochs_full.metadata)
assert epochs_full.drop_log == tfr.drop_log
for attr in ("events", "selection", "times"):
assert_array_equal(getattr(epochs_full, attr), getattr(tfr, attr))
# test pandas query
foo_a = tfr["foo == 'a'"]
bar_3 = tfr["bar <= 3"]
assert foo_a == bar_3
assert foo_a.shape[0] == 4 - n_drop
if not as_tfr_array: # check that various attributes are preserved
assert_frame_equal(tfr.metadata, epochs_full.metadata)
assert epochs_full.drop_log == tfr.drop_log
for attr in ("events", "selection", "times"):
assert_array_equal(getattr(epochs_full, attr), getattr(tfr, attr))
# test pandas query
foo_a = tfr["foo == 'a'"]
bar_3 = tfr["bar <= 3"]
assert foo_a == bar_3
assert foo_a.shape[0] == 4 - n_drop
else: # repackage to check __getitem__ also works with unspecified events, etc...
tfr = EpochsTFRArray(
info=tfr.info, data=tfr.data, times=tfr.times, freqs=tfr.freqs
)
# test integer and slice
subset_ints = tfr[[0, 1, 2]]
subset_slice = tfr[:3]
Expand Down
8 changes: 7 additions & 1 deletion mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3105,8 +3105,14 @@ def __setstate__(self, state):
).squeeze(axis=0)
self.events = state.get("events", _ensure_events(fake_events))
self.event_id = state.get("event_id", _check_event_id(None, self.events))
self.drop_log = state.get("drop_log", tuple())
self.selection = state.get("selection", np.arange(n_epochs))
self.drop_log = state.get(
"drop_log",
tuple(
() if k in self.selection else ("IGNORED",)
for k in range(max(len(self.events), max(self.selection) + 1))
),
)
self._bad_dropped = True # always true, need for `equalize_event_counts()`

def __next__(self, return_event_id=False):
Expand Down

0 comments on commit d814954

Please sign in to comment.