-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ENH] Add option to store and return TFR taper weights #12910
base: main
Are you sure you want to change the base?
Conversation
@@ -302,12 +306,15 @@ def _make_dpss( | |||
real_offset = Wk.mean() | |||
Wk -= real_offset | |||
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) | |||
Ck = np.sqrt(conc[m]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This I am somewhat unsure on. The existing implementation is to just use conc
as-is, however in the MNE-Connectivity implementation that sqrt is taken: https://github.com/mne-tools/mne-connectivity/blob/97147a57eefb36a5c9680e539fdc6343a1183f20/mne_connectivity/spectral/time.py#L825
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I noticed for the PSD computation that the square root of the weights is also taken, so I think this is okay:
mne-python/mne/time_frequency/multitaper.py
Line 412 in b329515
weights = np.sqrt(eigvals)[np.newaxis, :, np.newaxis] |
I'm also somewhat confused about the design of the mne-python/mne/time_frequency/tfr.py Lines 285 to 315 in 82fc2f7
It is looping over tapers, and then over frequencies. However, the Would it not be more efficient to only loop over frequencies and take advantage of the fact that this will also return information for each taper? |
I also have a question regarding testing: for the I/O tests, we're reading Apart from this there are still some tests I need to expand. |
@@ -302,12 +306,15 @@ def _make_dpss( | |||
real_offset = Wk.mean() | |||
Wk -= real_offset | |||
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) | |||
Ck = np.sqrt(conc[m]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the review @drammock! I will sort out those remaining tests, although I'm in the process of moving at the moment so it might not be for some days. Regarding those issues I came across with TFR multitapers and converting to dataframes / plotting: would you like me to incorporate that into this PR? |
Yes I think we should. most (all?) of them are created by pytest fixtures at present. I see 3 options:
To really test thoroughly, option (2) is probably best, because then you can also patch in things that are expected to fail, and test that they do fail in the expected way. |
@@ -1392,7 +1421,6 @@ def __setstate__(self, state): | |||
|
|||
defaults = dict( | |||
method="unknown", | |||
dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have removed dims
being set in BaseTFR
since the possibility of the optional epoch
and taper
dimensions makes it really difficult to disentangle here. It's much easier to handle this in the individual RawTFR
, EpochsTFR
, and AverageTFR
classes.
# Set dims now since optional tapers makes it difficult to disentangle later | ||
state["dims"] = ("channel",) | ||
if state["data"].ndim == 4: | ||
state["dims"] += ("taper",) | ||
state["dims"] += ("freq", "time") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Example of handling dims in the AverageTFR
class where only one dimension (taper
) is optional.
|
||
Averaging is not supported for data containing a taper dimension. | ||
""" | ||
if "taper" in self._dims: | ||
raise NotImplementedError( | ||
"Averaging multitaper tapers across epochs, frequencies, or times is " | ||
"not supported. If averaging across epochs, consider averaging the " | ||
"epochs before computing the complex/phase spectrum." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In terms of averaging for data with tapers, I went for the same approach we're using for Spectrum
and just disallowing this.
I don't think this is an API change requiring a deprecation cycle since:
- the docstring expects the data to not have a taper dimension, e.g.
If callable, must take a NumPy array of shape (n_epochs, n_channels, n_freqs, n_times)
. - trying to call this method on an object with a taper dimension would raise an uncaught error:
n_epochs, n_channels, n_freqs, n_times = self.data.shape
(wouldn't be able to unpack this properly).
So explicitly preventing this method being called with a taper dimension doesn't change current behaviour, it just gives a nicer error as to why this can't be done.
Notes | ||
----- | ||
Aggregating multitaper TFR datasets with a taper dimension such as for complex or | ||
phase data is not supported. | ||
|
||
.. versionadded:: 0.11.0 | ||
""" | ||
if any("taper" in tfr._dims for tfr in all_tfr): | ||
raise NotImplementedError( | ||
"Aggregating multitaper tapers across TFR datasets is not supported." | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a similar case to averaging for the time_frequency.combine_tfr()
function (which also gets called by the grand_average()
function).
However, unlike the EpochsTFR.average()
method, this could be considered an API change since combine_tfr()
should currently run with taper data. Does preventing this use case require a deprecation cycle?
On a side note, I noticed that while a public function, combine_tfr()
is not listed in the API (the equivalent combine_evoked()
is). Is this an oversight or an intended omission?
… into add_tfr_weights
Those recent pushes added support for data with a tapers dimension in the |
Now Just sorting the issues with plotting to go! |
Just looking into the sorting the power this morning and I am a little confused by the procedure being used to convert the complex taper coeffs into power, as it seems that no taper weights are ever applied. I opened an issue to try and figure out if this is a mistake, or a misunderstanding on my part: #13023 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Latest push adds support for plotting of data with a taper dimension (aggregates over tapers before plotting and converts to power (if complex coeffs) or keeps as phase data).
Also has test coverage.
@pytest.mark.parametrize("output", ("complex", "phase")) | ||
def test_tfr_topo_plotting_multitaper_complex_phase(output, evoked): | ||
"""Test plot_joint/topo/topomap() for data with a taper dimension.""" | ||
# Compute TFR with taper dimension | ||
tfr = evoked.compute_tfr( | ||
method="multitaper", freqs=freqs_linspace, n_cycles=4, output=output | ||
) | ||
# Check that plotting works | ||
tfr.plot_joint(topomap_args=dict(res=8, contours=0, sensors=False)) # for speed | ||
tfr.plot_topo() | ||
tfr.plot_topomap() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Basic test that just checks whether the code runs, but it covers the lines where changes to topo-related plotting were made, and other tests deal with non-default method params.
@pytest.mark.parametrize("output", ("complex", "phase")) | ||
def test_plot_multitaper_complex_phase(output): | ||
"""Test TFR plotting of data with a taper dimension.""" | ||
# Create example data with a taper dimension | ||
n_chans, n_tapers, n_freqs, n_times = (3, 4, 2, 3) | ||
data = np.random.rand(n_chans, n_tapers, n_freqs, n_times) | ||
if output == "complex": | ||
data = data + np.random.rand(*data.shape) * 1j # add imaginary data | ||
times = np.arange(n_times) | ||
freqs = np.arange(n_freqs) | ||
weights = np.random.rand(n_tapers, n_freqs) | ||
info = mne.create_info(n_chans, 1000.0, "eeg") | ||
tfr = AverageTFRArray( | ||
info=info, data=data, times=times, freqs=freqs, weights=weights | ||
) | ||
# Check that plotting works | ||
tfr.plot() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again, a pretty basic test that just checks whether plotting code runs, but covers the changes and non-default params tested elswehere.
# TODO this is the only remaining call to _preproc_tfr; should be refactored | ||
# (to use _prep_data_for_plot?) | ||
data, times, freqs, vmin, vmax = _preproc_tfr( | ||
# baseline, crop, convert complex to power, aggregate tapers, and dB scaling | ||
data, times, freqs = _prep_data_for_plot( | ||
data, | ||
times, | ||
freqs, | ||
tmin, | ||
tmax, | ||
fmin, | ||
fmax, | ||
mode, | ||
baseline, | ||
vmin, | ||
vmax, | ||
dB, | ||
info["sfreq"], | ||
tmin=tmin, | ||
tmax=tmax, | ||
fmin=fmin, | ||
fmax=fmax, | ||
baseline=baseline, | ||
mode=mode, | ||
dB=dB, | ||
taper_weights=self.weights, | ||
verbose=verbose, | ||
) | ||
# get vlims | ||
vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seemed like as good as time as any to refactor and replace the _preproc_tfr
call with _prep_data_for_plot
where changes for handling data with a taper dimension have been made.
# handle unaggregated multitaper (complex or phase multitaper data) | ||
if tfr.weights is not None: # assumes a taper dimension | ||
logger.info("Aggregating multitaper estimates before plotting...") | ||
weights = tfr.weights[np.newaxis, :, :, np.newaxis] # add channel & time dims | ||
data = weights * data | ||
if np.iscomplexobj(data): # complex coefficients → power | ||
data *= data.conj() | ||
data = data.real.sum(axis=1) | ||
data *= 2 / (weights * weights.conj()).real.sum(axis=1) | ||
else: # tapered phase data → weighted phase data | ||
data = data.mean(axis=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also the case that viz.plot_tfr_topomap()
needs to be able to handle data with a tapers dim. Unfortunately circular imports mean the code for handling taper dims from tfr.py
can't be used here, so there's a bit of code repetition.
else: # tapered phase data → weighted phase data | ||
data = (data * taper_weights[np.newaxis, :, :, np.newaxis]).mean(axis=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is my guess at aggregating over tapers for phase data. Can anyone confirm if this is correct?
tfr = weights * x_mt | ||
tfr *= tfr.conj() | ||
tfr = tfr.real.sum(axis=1) | ||
tfr *= 2 / (weights * weights.conj()).real.sum(axis=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This aggregation over tapers for the complex spectra follows the same procedure for computing PSDs and how we handle TFR data in MNE-Connectivity.
However, it does differ to how this aggregation is handled elsewhere in the TFR classes (see #13023), so it would be nice to clarify the correct approach.
As discussed there this is a bug but will be addressed in a separate PR once this is merged. |
Reference issue (if any)
PR for #12851
What does this implement/fix?
Adds an option to return taper weights for complex and phase outputs of the multitaper method in
tfr_array_multitaper()
, and also ensures taper weights are stored inTFR
objects.Additional information
When working on this, I discovered a couple of other issues with the per-taper TFR implementations (#12851 (comment)), including the fact that the
TFR
object plotting methods andto_data_frame
methods do not account for a taper dimension, leading to errors. Wasn't sure if people want me to also address these here or in a separate PR.