Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add option to store and return TFR taper weights #12910

Open
wants to merge 38 commits into
base: main
Choose a base branch
from

Conversation

tsbinns
Copy link
Contributor

@tsbinns tsbinns commented Oct 22, 2024

Reference issue (if any)

PR for #12851

What does this implement/fix?

Adds an option to return taper weights for complex and phase outputs of the multitaper method in tfr_array_multitaper(), and also ensures taper weights are stored in TFR objects.

Additional information

When working on this, I discovered a couple of other issues with the per-taper TFR implementations (#12851 (comment)), including the fact that the TFR object plotting methods and to_data_frame methods do not account for a taper dimension, leading to errors. Wasn't sure if people want me to also address these here or in a separate PR.

@@ -302,12 +306,15 @@ def _make_dpss(
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This I am somewhat unsure on. The existing implementation is to just use conc as-is, however in the MNE-Connectivity implementation that sqrt is taken: https://github.com/mne-tools/mne-connectivity/blob/97147a57eefb36a5c9680e539fdc6343a1183f20/mne_connectivity/spectral/time.py#L825

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also unsure on this point. We should ask @ruuskas (who wrote the implementation in MNE-Connectivity) and @larsoner (who wrote the SciPy DPSS implementation) to weigh in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed for the PSD computation that the square root of the weights is also taken, so I think this is okay:

weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis]

@tsbinns
Copy link
Contributor Author

tsbinns commented Oct 22, 2024

I'm also somewhat confused about the design of the _make_dpss function:

for m in range(n_taps):
Wm = list()
Cm = list()
for k, f in enumerate(freqs):
if len(n_cycles) != 1:
this_n_cycles = n_cycles[k]
else:
this_n_cycles = n_cycles[0]
t_win = this_n_cycles / float(f)
t = np.arange(0.0, t_win, 1.0 / sfreq)
# Making sure wavelets are centered before tapering
oscillation = np.exp(2.0 * 1j * np.pi * f * (t - t_win / 2.0))
# Get dpss tapers
tapers, conc = dpss_windows(
t.shape[0], time_bandwidth / 2.0, n_taps, sym=False
)
Wk = oscillation * tapers[m]
if zero_mean: # to make it zero mean
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])
Wm.append(Wk)
Cm.append(Ck)
Ws.append(Wm)
Cs.append(Cm)

It is looping over tapers, and then over frequencies. However, the dpss_windows function it calls internally provides the tapers and their weights for all tapers of a given frequency.

Would it not be more efficient to only loop over frequencies and take advantage of the fact that this will also return information for each taper?

@tsbinns
Copy link
Contributor Author

tsbinns commented Oct 22, 2024

I also have a question regarding testing: for the I/O tests, we're reading TFR objects that do not have a weights property (just gets assigned to None) when loaded. Do I need to create new TFR objects that actually have some weights, or is the current test sufficient?

Apart from this there are still some tests I need to expand.

mne/time_frequency/multitaper.py Outdated Show resolved Hide resolved
mne/time_frequency/tfr.py Show resolved Hide resolved
@@ -302,12 +306,15 @@ def _make_dpss(
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am also unsure on this point. We should ask @ruuskas (who wrote the implementation in MNE-Connectivity) and @larsoner (who wrote the SciPy DPSS implementation) to weigh in.

mne/time_frequency/tfr.py Show resolved Hide resolved
@tsbinns
Copy link
Contributor Author

tsbinns commented Oct 29, 2024

Thanks for the review @drammock! I will sort out those remaining tests, although I'm in the process of moving at the moment so it might not be for some days.

Regarding those issues I came across with TFR multitapers and converting to dataframes / plotting: would you like me to incorporate that into this PR?

@drammock
Copy link
Member

drammock commented Dec 9, 2024

Do I need to create new TFR objects that actually have some weights, or is the current test sufficient?

Yes I think we should. most (all?) of them are created by pytest fixtures at present. I see 3 options:

  1. tweak the fixtures to always return TFRs that have weights.
  2. when you want to test something specific to weights, monkey-patch some weights (and a taper dim) into the object at the start of the test
  3. write a new fixture (or parametrize an existing one) so that you can get TFRs with/without weights at need.

To really test thoroughly, option (2) is probably best, because then you can also patch in things that are expected to fail, and test that they do fail in the expected way.

@@ -1392,7 +1421,6 @@ def __setstate__(self, state):

defaults = dict(
method="unknown",
dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have removed dims being set in BaseTFR since the possibility of the optional epoch and taper dimensions makes it really difficult to disentangle here. It's much easier to handle this in the individual RawTFR, EpochsTFR, and AverageTFR classes.

Comment on lines +2909 to +2913
# Set dims now since optional tapers makes it difficult to disentangle later
state["dims"] = ("channel",)
if state["data"].ndim == 4:
state["dims"] += ("taper",)
state["dims"] += ("freq", "time")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example of handling dims in the AverageTFR class where only one dimension (taper) is optional.

Comment on lines +3278 to +3286

Averaging is not supported for data containing a taper dimension.
"""
if "taper" in self._dims:
raise NotImplementedError(
"Averaging multitaper tapers across epochs, frequencies, or times is "
"not supported. If averaging across epochs, consider averaging the "
"epochs before computing the complex/phase spectrum."
)
Copy link
Contributor Author

@tsbinns tsbinns Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In terms of averaging for data with tapers, I went for the same approach we're using for Spectrum and just disallowing this.

I don't think this is an API change requiring a deprecation cycle since:

  1. the docstring expects the data to not have a taper dimension, e.g. If callable, must take a NumPy array of shape (n_epochs, n_channels, n_freqs, n_times).
  2. trying to call this method on an object with a taper dimension would raise an uncaught error: n_epochs, n_channels, n_freqs, n_times = self.data.shape (wouldn't be able to unpack this properly).

So explicitly preventing this method being called with a taper dimension doesn't change current behaviour, it just gives a nicer error as to why this can't be done.

Comment on lines 3942 to +3953
Notes
-----
Aggregating multitaper TFR datasets with a taper dimension such as for complex or
phase data is not supported.

.. versionadded:: 0.11.0
"""
if any("taper" in tfr._dims for tfr in all_tfr):
raise NotImplementedError(
"Aggregating multitaper tapers across TFR datasets is not supported."
)

Copy link
Contributor Author

@tsbinns tsbinns Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a similar case to averaging for the time_frequency.combine_tfr() function (which also gets called by the grand_average() function).

However, unlike the EpochsTFR.average() method, this could be considered an API change since combine_tfr() should currently run with taper data. Does preventing this use case require a deprecation cycle?

On a side note, I noticed that while a public function, combine_tfr() is not listed in the API (the equivalent combine_evoked() is). Is this an oversight or an intended omission?

@tsbinns
Copy link
Contributor Author

tsbinns commented Dec 11, 2024

Those recent pushes added support for data with a tapers dimension in the ...TFRArray objects which was no fully accounted for before.

@tsbinns
Copy link
Contributor Author

tsbinns commented Dec 11, 2024

Now to_data_frame works for data with a tapers dimension (alongside unit tests).

Just sorting the issues with plotting to go!

@tsbinns
Copy link
Contributor Author

tsbinns commented Dec 12, 2024

Just looking into the sorting the power this morning and I am a little confused by the procedure being used to convert the complex taper coeffs into power, as it seems that no taper weights are ever applied. I opened an issue to try and figure out if this is a mistake, or a misunderstanding on my part: #13023

Copy link
Contributor Author

@tsbinns tsbinns left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Latest push adds support for plotting of data with a taper dimension (aggregates over tapers before plotting and converts to power (if complex coeffs) or keeps as phase data).
Also has test coverage.

Comment on lines +1852 to +1862
@pytest.mark.parametrize("output", ("complex", "phase"))
def test_tfr_topo_plotting_multitaper_complex_phase(output, evoked):
"""Test plot_joint/topo/topomap() for data with a taper dimension."""
# Compute TFR with taper dimension
tfr = evoked.compute_tfr(
method="multitaper", freqs=freqs_linspace, n_cycles=4, output=output
)
# Check that plotting works
tfr.plot_joint(topomap_args=dict(res=8, contours=0, sensors=False)) # for speed
tfr.plot_topo()
tfr.plot_topomap()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basic test that just checks whether the code runs, but it covers the lines where changes to topo-related plotting were made, and other tests deal with non-default method params.

Comment on lines +862 to +878
@pytest.mark.parametrize("output", ("complex", "phase"))
def test_plot_multitaper_complex_phase(output):
"""Test TFR plotting of data with a taper dimension."""
# Create example data with a taper dimension
n_chans, n_tapers, n_freqs, n_times = (3, 4, 2, 3)
data = np.random.rand(n_chans, n_tapers, n_freqs, n_times)
if output == "complex":
data = data + np.random.rand(*data.shape) * 1j # add imaginary data
times = np.arange(n_times)
freqs = np.arange(n_freqs)
weights = np.random.rand(n_tapers, n_freqs)
info = mne.create_info(n_chans, 1000.0, "eeg")
tfr = AverageTFRArray(
info=info, data=data, times=times, freqs=freqs, weights=weights
)
# Check that plotting works
tfr.plot()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, a pretty basic test that just checks whether plotting code runs, but covers the changes and non-default params tested elswehere.

Comment on lines -2467 to +2552
# TODO this is the only remaining call to _preproc_tfr; should be refactored
# (to use _prep_data_for_plot?)
data, times, freqs, vmin, vmax = _preproc_tfr(
# baseline, crop, convert complex to power, aggregate tapers, and dB scaling
data, times, freqs = _prep_data_for_plot(
data,
times,
freqs,
tmin,
tmax,
fmin,
fmax,
mode,
baseline,
vmin,
vmax,
dB,
info["sfreq"],
tmin=tmin,
tmax=tmax,
fmin=fmin,
fmax=fmax,
baseline=baseline,
mode=mode,
dB=dB,
taper_weights=self.weights,
verbose=verbose,
)
# get vlims
vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seemed like as good as time as any to refactor and replace the _preproc_tfr call with _prep_data_for_plot where changes for handling data with a taper dimension have been made.

Comment on lines +1893 to +1903
# handle unaggregated multitaper (complex or phase multitaper data)
if tfr.weights is not None: # assumes a taper dimension
logger.info("Aggregating multitaper estimates before plotting...")
weights = tfr.weights[np.newaxis, :, :, np.newaxis] # add channel & time dims
data = weights * data
if np.iscomplexobj(data): # complex coefficients → power
data *= data.conj()
data = data.real.sum(axis=1)
data *= 2 / (weights * weights.conj()).real.sum(axis=1)
else: # tapered phase data → weighted phase data
data = data.mean(axis=1)
Copy link
Contributor Author

@tsbinns tsbinns Dec 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also the case that viz.plot_tfr_topomap() needs to be able to handle data with a tapers dim. Unfortunately circular imports mean the code for handling taper dims from tfr.py can't be used here, so there's a bit of code repetition.

Comment on lines +4316 to +4317
else: # tapered phase data → weighted phase data
data = (data * taper_weights[np.newaxis, :, :, np.newaxis]).mean(axis=1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is my guess at aggregating over tapers for phase data. Can anyone confirm if this is correct?

Comment on lines +4342 to +4345
tfr = weights * x_mt
tfr *= tfr.conj()
tfr = tfr.real.sum(axis=1)
tfr *= 2 / (weights * weights.conj()).real.sum(axis=1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This aggregation over tapers for the complex spectra follows the same procedure for computing PSDs and how we handle TFR data in MNE-Connectivity.
However, it does differ to how this aggregation is handled elsewhere in the TFR classes (see #13023), so it would be nice to clarify the correct approach.

@larsoner larsoner added this to the 1.10 milestone Dec 16, 2024
@tsbinns
Copy link
Contributor Author

tsbinns commented Dec 19, 2024

Just looking into the sorting the power this morning and I am a little confused by the procedure being used to convert the complex taper coeffs into power, as it seems that no taper weights are ever applied. I opened an issue to try and figure out if this is a mistake, or a misunderstanding on my part: #13023

As discussed there this is a bug but will be addressed in a separate PR once this is merged.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants