Skip to content

Commit

Permalink
updating TFR classes (mne-tools#11282)
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Larson <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Mar 20, 2024
1 parent cf6e13c commit 925f522
Show file tree
Hide file tree
Showing 44 changed files with 4,565 additions and 2,388 deletions.
5 changes: 5 additions & 0 deletions doc/api/time_frequency.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ Time-Frequency
:toctree: ../generated/

AverageTFR
AverageTFRArray
BaseTFR
EpochsTFR
EpochsTFRArray
RawTFR
RawTFRArray
CrossSpectralDensity
Spectrum
SpectrumArray
Expand Down
1 change: 1 addition & 0 deletions doc/changes/devel/11282.apichange.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The default value of the ``zero_mean`` parameter of :func:`mne.time_frequency.tfr_array_morlet` will change from ``False`` to ``True`` in version 1.8, for consistency with related functions. By `Daniel McCloy`_.
1 change: 1 addition & 0 deletions doc/changes/devel/11282.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixes to interactivity in time-frequency objects: the rectangle selector now works on TFR image plots of gradiometer data; and in ``TFR.plot_joint()`` plots, the colormap limits of interactively-generated topomaps match the colormap limits of the main plot. By `Daniel McCloy`_.
1 change: 1 addition & 0 deletions doc/changes/devel/11282.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
New class :class:`mne.time_frequency.RawTFR` and new methods :meth:`mne.io.Raw.compute_tfr`, :meth:`mne.Epochs.compute_tfr`, and :meth:`mne.Evoked.compute_tfr`. These new methods supersede functions :func:`mne.time_frequency.tfr_morlet`, and :func:`mne.time_frequency.tfr_multitaper`, and :func:`mne.time_frequency.tfr_stockwell`, which are now considered "legacy" functions. By `Daniel McCloy`_.
4 changes: 4 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,11 @@
"EvokedArray": "mne.EvokedArray",
"BiHemiLabel": "mne.BiHemiLabel",
"AverageTFR": "mne.time_frequency.AverageTFR",
"AverageTFRArray": "mne.time_frequency.AverageTFRArray",
"EpochsTFR": "mne.time_frequency.EpochsTFR",
"EpochsTFRArray": "mne.time_frequency.EpochsTFRArray",
"RawTFR": "mne.time_frequency.RawTFR",
"RawTFRArray": "mne.time_frequency.RawTFRArray",
"Raw": "mne.io.Raw",
"ICA": "mne.preprocessing.ICA",
"Covariance": "mne.Covariance",
Expand Down
18 changes: 10 additions & 8 deletions examples/decoding/decoding_csp_timefreq.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.io import concatenate_raws, read_raw_edf
from mne.time_frequency import AverageTFR
from mne.time_frequency import AverageTFRArray

# %%
# Set parameters and read data
Expand Down Expand Up @@ -173,13 +173,15 @@
# Plot time-frequency results

# Set up time frequency object
av_tfr = AverageTFR(
create_info(["freq"], sfreq),
tf_scores[np.newaxis, :],
centered_w_times,
freqs[1:],
1,
av_tfr = AverageTFRArray(
info=create_info(["freq"], sfreq),
data=tf_scores[np.newaxis, :],
times=centered_w_times,
freqs=freqs[1:],
nave=1,
)

chance = np.mean(y) # set chance level to white in the plot
av_tfr.plot([0], vmin=chance, title="Time-Frequency Decoding Scores", cmap=plt.cm.Reds)
av_tfr.plot(
[0], vlim=(chance, None), title="Time-Frequency Decoding Scores", cmap=plt.cm.Reds
)
6 changes: 3 additions & 3 deletions examples/inverse/dics_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import mne
from mne.beamformer import apply_dics_tfr_epochs, make_dics
from mne.datasets import somato
from mne.time_frequency import csd_tfr, tfr_morlet
from mne.time_frequency import csd_tfr

print(__doc__)

Expand Down Expand Up @@ -67,8 +67,8 @@
# decomposition for each epoch. We must pass ``output='complex'`` if we wish to
# use this TFR later with a DICS beamformer. We also pass ``average=False`` to
# compute the TFR for each individual epoch.
epochs_tfr = tfr_morlet(
epochs, freqs, n_cycles=5, return_itc=False, output="complex", average=False
epochs_tfr = epochs.compute_tfr(
"morlet", freqs, n_cycles=5, return_itc=False, output="complex", average=False
)

# crop either side to use a buffer to remove edge artifact
Expand Down
5 changes: 2 additions & 3 deletions examples/time_frequency/time_frequency_erds.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from mne.datasets import eegbci
from mne.io import concatenate_raws, read_raw_edf
from mne.stats import permutation_cluster_1samp_test as pcluster_test
from mne.time_frequency import tfr_multitaper

# %%
# First, we load and preprocess the data. We use runs 6, 10, and 14 from
Expand Down Expand Up @@ -96,8 +95,8 @@

# %%
# Finally, we perform time/frequency decomposition over all epochs.
tfr = tfr_multitaper(
epochs,
tfr = epochs.compute_tfr(
method="multitaper",
freqs=freqs,
n_cycles=freqs,
use_fft=True,
Expand Down
80 changes: 38 additions & 42 deletions examples/time_frequency/time_frequency_simulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,8 @@
from matplotlib import pyplot as plt

from mne import Epochs, create_info
from mne.baseline import rescale
from mne.io import RawArray
from mne.time_frequency import (
AverageTFR,
tfr_array_morlet,
tfr_morlet,
tfr_multitaper,
tfr_stockwell,
)
from mne.viz import centers_to_edges
from mne.time_frequency import AverageTFRArray, EpochsTFRArray, tfr_array_morlet

print(__doc__)

Expand Down Expand Up @@ -112,21 +104,21 @@
"Sim: Less time smoothing,\nmore frequency smoothing",
],
):
power = tfr_multitaper(
epochs,
power = epochs.compute_tfr(
method="multitaper",
freqs=freqs,
n_cycles=n_cycles,
time_bandwidth=time_bandwidth,
return_itc=False,
average=True,
)
ax.set_title(title)
# Plot results. Baseline correct based on first 100 ms.
power.plot(
[0],
baseline=(0.0, 0.1),
mode="mean",
vmin=vmin,
vmax=vmax,
vlim=(vmin, vmax),
axes=ax,
show=False,
colorbar=False,
Expand All @@ -146,7 +138,7 @@
fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained")
fmin, fmax = freqs[[0, -1]]
for width, ax in zip((0.2, 0.7, 3.0), axs):
power = tfr_stockwell(epochs, fmin=fmin, fmax=fmax, width=width)
power = epochs.compute_tfr(method="stockwell", freqs=(fmin, fmax), width=width)
power.plot(
[0], baseline=(0.0, 0.1), mode="mean", axes=ax, show=False, colorbar=False
)
Expand All @@ -164,13 +156,14 @@
fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained")
all_n_cycles = [1, 3, freqs / 2.0]
for n_cycles, ax in zip(all_n_cycles, axs):
power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False)
power = epochs.compute_tfr(
method="morlet", freqs=freqs, n_cycles=n_cycles, return_itc=False, average=True
)
power.plot(
[0],
baseline=(0.0, 0.1),
mode="mean",
vmin=vmin,
vmax=vmax,
vlim=(vmin, vmax),
axes=ax,
show=False,
colorbar=False,
Expand All @@ -190,7 +183,9 @@
fig, axs = plt.subplots(1, 3, figsize=(15, 5), sharey=True, layout="constrained")
bandwidths = [1.0, 2.0, 4.0]
for bandwidth, ax in zip(bandwidths, axs):
data = np.zeros((len(ch_names), freqs.size, epochs.times.size), dtype=complex)
data = np.zeros(
(len(epochs), len(ch_names), freqs.size, epochs.times.size), dtype=complex
)
for idx, freq in enumerate(freqs):
# Filter raw data and re-epoch to avoid the filter being longer than
# the epoch data for low frequencies and short epochs, such as here.
Expand All @@ -210,17 +205,13 @@
epochs_hilb = Epochs(
raw_filter, events, tmin=0, tmax=n_times / sfreq, baseline=(0, 0.1)
)
tfr_data = epochs_hilb.get_data()
tfr_data = tfr_data * tfr_data.conj() # compute power
tfr_data = np.mean(tfr_data, axis=0) # average over epochs
data[:, idx] = tfr_data
power = AverageTFR(info, data, epochs.times, freqs, nave=n_epochs)
power.plot(
data[:, :, idx] = epochs_hilb.get_data()
power = EpochsTFRArray(epochs.info, data, epochs.times, freqs, method="hilbert")
power.average().plot(
[0],
baseline=(0.0, 0.1),
mode="mean",
vmin=-0.1,
vmax=0.1,
vlim=(0, 0.1),
axes=ax,
show=False,
colorbar=False,
Expand All @@ -241,17 +232,16 @@
# :class:`mne.time_frequency.EpochsTFR` is returned.

n_cycles = freqs / 2.0
power = tfr_morlet(
epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False
power = epochs.compute_tfr(
method="morlet", freqs=freqs, n_cycles=n_cycles, return_itc=False, average=False
)
print(type(power))
avgpower = power.average()
avgpower.plot(
[0],
baseline=(0.0, 0.1),
mode="mean",
vmin=vmin,
vmax=vmax,
vlim=(vmin, vmax),
title="Using Morlet wavelets and EpochsTFR",
show=False,
)
Expand All @@ -260,23 +250,29 @@
# Operating on arrays
# -------------------
#
# MNE also has versions of the functions above which operate on numpy arrays
# instead of MNE objects. They expect inputs of the shape
# ``(n_epochs, n_channels, n_times)``. They will also return a numpy array
# of shape ``(n_epochs, n_channels, n_freqs, n_times)``.
# MNE-Python also has functions that operate on :class:`NumPy arrays <numpy.ndarray>`
# instead of MNE-Python objects. These are :func:`~mne.time_frequency.tfr_array_morlet`
# and :func:`~mne.time_frequency.tfr_array_multitaper`. They expect inputs of the shape
# ``(n_epochs, n_channels, n_times)`` and return an array of shape
# ``(n_epochs, n_channels, n_freqs, n_times)`` (or optionally, can collapse the epochs
# dimension if you want average power or inter-trial coherence; see ``output`` param).

power = tfr_array_morlet(
epochs.get_data(),
sfreq=epochs.info["sfreq"],
freqs=freqs,
n_cycles=n_cycles,
output="avg_power",
zero_mean=False,
)
# Put it into a TFR container for easy plotting
tfr = AverageTFRArray(
info=epochs.info, data=power, times=epochs.times, freqs=freqs, nave=len(epochs)
)
tfr.plot(
baseline=(0.0, 0.1),
picks=[0],
mode="mean",
vlim=(vmin, vmax),
title="TFR calculated on a NumPy array",
)
# Baseline the output
rescale(power, epochs.times, (0.0, 0.1), mode="mean", copy=False)
fig, ax = plt.subplots(layout="constrained")
x, y = centers_to_edges(epochs.times * 1000, freqs)
mesh = ax.pcolormesh(x, y, power[0], cmap="RdBu_r", vmin=vmin, vmax=vmax)
ax.set_title("TFR calculated on a numpy array")
ax.set(ylim=freqs[[0, -1]], xlabel="Time (ms)")
fig.colorbar(mesh)
4 changes: 2 additions & 2 deletions mne/beamformer/tests/test_dics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from mne.io import read_info
from mne.proj import compute_proj_evoked, make_projector
from mne.surface import _compute_nearest
from mne.time_frequency import CrossSpectralDensity, EpochsTFR, csd_morlet, csd_tfr
from mne.time_frequency import CrossSpectralDensity, EpochsTFRArray, csd_morlet, csd_tfr
from mne.time_frequency.csd import _sym_mat_to_vector
from mne.transforms import apply_trans, invert_transform
from mne.utils import catch_logging, object_diff
Expand Down Expand Up @@ -727,7 +727,7 @@ def test_apply_dics_tfr(return_generator):
data = rng.random((n_epochs, n_chans, len(freqs), n_times))
data *= 1e-6
data = data + data * 1j # add imag. component to simulate phase
epochs_tfr = EpochsTFR(info, data, times=times, freqs=freqs)
epochs_tfr = EpochsTFRArray(info=info, data=data, times=times, freqs=freqs)

# Create a DICS beamformer and convert the EpochsTFR to source space.
csd = csd_tfr(epochs_tfr)
Expand Down
10 changes: 3 additions & 7 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ def equalize_channels(instances, copy=True, verbose=None):
from ..evoked import Evoked
from ..forward import Forward
from ..io import BaseRaw
from ..time_frequency import CrossSpectralDensity, _BaseTFR
from ..time_frequency import BaseTFR, CrossSpectralDensity

# Instances need to have a `ch_names` attribute and a `pick_channels`
# method that supports `ordered=True`.
allowed_types = (
BaseRaw,
BaseEpochs,
Evoked,
_BaseTFR,
BaseTFR,
Forward,
Covariance,
CrossSpectralDensity,
Expand Down Expand Up @@ -607,8 +607,6 @@ def drop_channels(self, ch_names, on_missing="raise"):
def _pick_drop_channels(self, idx, *, verbose=None):
# avoid circular imports
from ..io import BaseRaw
from ..time_frequency import AverageTFR, EpochsTFR
from ..time_frequency.spectrum import BaseSpectrum

msg = "adding, dropping, or reordering channels"
if isinstance(self, BaseRaw):
Expand All @@ -633,10 +631,8 @@ def _pick_drop_channels(self, idx, *, verbose=None):
if mat is not None:
setattr(self, key, mat[idx][:, idx])

if isinstance(self, BaseSpectrum):
if hasattr(self, "_dims"): # Spectrum and "new-style" TFRs
axis = self._dims.index("channel")
elif isinstance(self, (AverageTFR, EpochsTFR)):
axis = -3
else: # All others (Evoked, Epochs, Raw) have chs axis=-2
axis = -2
if hasattr(self, "_data"): # skip non-preloaded Raw
Expand Down
Loading

0 comments on commit 925f522

Please sign in to comment.