Skip to content

Commit

Permalink
ENH: Add image_kwargs to report.add_epochs (mne-tools#12443)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Richard Höchenberger <[email protected]>
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
4 people authored Feb 16, 2024
1 parent 85ca0ed commit 5e23fe0
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 6 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/12443.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add option to pass ``image_kwargs`` to :class:`mne.Report.add_epochs` to allow adjusting e.g. ``vmin`` and ``vmax`` of the epochs image in the report, by `Sophie Herbst`_.
19 changes: 17 additions & 2 deletions mne/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,7 @@ def add_epochs(
*,
psd=True,
projs=None,
image_kwargs=None,
topomap_kwargs=None,
drop_log_ignore=("IGNORED",),
tags=("epochs",),
Expand Down Expand Up @@ -1120,6 +1121,11 @@ def add_epochs(
If ``True``, add PSD plots based on all ``epochs``. If ``False``,
do not add PSD plots.
%(projs_report)s
image_kwargs : dict | None
Keyword arguments to pass to the "epochs image"-generating
function (:meth:`mne.Epochs.plot_image`).
.. versionadded:: 1.7
%(topomap_kwargs)s
drop_log_ignore : array-like of str
The drop reasons to ignore when creating the drop log bar plot.
Expand All @@ -1130,14 +1136,15 @@ def add_epochs(
Notes
-----
.. versionadded:: 0.24.0
.. versionadded:: 0.24
"""
tags = _check_tags(tags)
add_projs = self.projs if projs is None else projs
self._add_epochs(
epochs=epochs,
psd=psd,
add_projs=add_projs,
image_kwargs=image_kwargs,
topomap_kwargs=topomap_kwargs,
drop_log_ignore=drop_log_ignore,
section=title,
Expand Down Expand Up @@ -3900,6 +3907,7 @@ def _add_epochs(
epochs,
psd,
add_projs,
image_kwargs,
topomap_kwargs,
drop_log_ignore,
image_format,
Expand Down Expand Up @@ -3934,9 +3942,16 @@ def _add_epochs(
ch_types = _get_data_ch_types(epochs)
epochs.load_data()

if image_kwargs is None:
image_kwargs = dict()

for ch_type in ch_types:
with use_log_level(_verbose_safe_false(level="error")):
figs = epochs.copy().pick(ch_type, verbose=False).plot_image(show=False)
figs = (
epochs.copy()
.pick(ch_type, verbose=False)
.plot_image(show=False, **image_kwargs)
)

assert len(figs) == 1
fig = figs[0]
Expand Down
18 changes: 14 additions & 4 deletions mne/report/tests/test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,8 @@ def test_manual_report_2d(tmp_path, invisible_fig):
raw = read_raw_fif(raw_fname)
raw.pick(raw.ch_names[:6]).crop(10, None)
raw.info.normalize_proj()
raw_non_preloaded = raw.copy()
raw.load_data()
cov = read_cov(cov_fname)
cov = pick_channels_cov(cov, raw.ch_names)
events = read_events(events_fname)
Expand All @@ -899,17 +901,24 @@ def test_manual_report_2d(tmp_path, invisible_fig):
events=events, event_id=event_id, tmin=-0.2, tmax=0.5, sfreq=raw.info["sfreq"]
)
epochs_without_metadata = Epochs(
raw=raw, events=events, event_id=event_id, baseline=None
raw=raw,
events=events,
event_id=event_id,
baseline=None,
decim=10,
verbose="error",
)
epochs_with_metadata = Epochs(
raw=raw,
events=metadata_events,
event_id=metadata_event_id,
baseline=None,
metadata=metadata,
decim=10,
verbose="error",
)
evokeds = read_evokeds(evoked_fname)
evoked = evokeds[0].pick("eeg")
evoked = evokeds[0].pick("eeg").decimate(10, verbose="error")

with pytest.warns(ConvergenceWarning, match="did not converge"):
ica = ICA(n_components=3, max_iter=1, random_state=42).fit(
Expand All @@ -927,6 +936,7 @@ def test_manual_report_2d(tmp_path, invisible_fig):
tags=("epochs",),
psd=False,
projs=False,
image_kwargs=dict(colorbar=False),
)
r.add_epochs(
epochs=epochs_without_metadata, title="my epochs 2", psd=1, projs=False
Expand Down Expand Up @@ -963,11 +973,11 @@ def test_manual_report_2d(tmp_path, invisible_fig):
)
r.add_ica(ica=ica, title="my ica", inst=None)
with pytest.raises(RuntimeError, match="not preloaded"):
r.add_ica(ica=ica, title="ica", inst=raw)
r.add_ica(ica=ica, title="ica", inst=raw_non_preloaded)
r.add_ica(
ica=ica,
title="my ica with raw inst",
inst=raw.copy().load_data(),
inst=raw,
picks=[2],
ecg_evoked=ica_ecg_evoked,
eog_evoked=ica_eog_evoked,
Expand Down

0 comments on commit 5e23fe0

Please sign in to comment.