Skip to content

Commit

Permalink
fix scaling in Spectrum.plot(amplitude=True, dB=True) (#13036)
Browse files Browse the repository at this point in the history
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
drammock and autofix-ci[bot] authored Dec 17, 2024
1 parent dcd2625 commit 41dbdd5
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/changes/devel/13036.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix plot scaling for :meth:`Spectrum.plot(dB=True, amplitude=True) <mne.time_frequency.Spectrum.plot>`, by `Daniel McCloy`_.
26 changes: 26 additions & 0 deletions mne/time_frequency/tests/test_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,3 +627,29 @@ def test_plot_spectrum_array_with_bads():
spectrum.get_data(exclude=()), spectrum.info, spectrum.freqs
)
spectrum2.plot(spatial_colors=False)


@pytest.mark.parametrize("dB", (False, True))
@pytest.mark.parametrize("amplitude", (False, True))
def test_plot_spectrum_dB(raw_spectrum, dB, amplitude):
"""Test that we properly handle amplitude/power and dB."""
idx = 7
power = 3
freqs = np.linspace(1, 100, 100)
data = np.full((1, freqs.size), np.finfo(float).tiny)
data[0, idx] = power
info = create_info(ch_names=["delta"], sfreq=1000, ch_types="eeg")
psd = SpectrumArray(data=data, info=info, freqs=freqs)
with pytest.warns(RuntimeWarning, match="Channel locations not available"):
fig = psd.plot(dB=dB, amplitude=amplitude)
trace = list(
filter(lambda x: len(x.get_data()[0]) == len(freqs), fig.axes[0].lines)
)[0]
got = trace.get_data()[1][idx]
want = power * 1e12 # scaling for EEG (V → μV), squared
if amplitude:
want = np.sqrt(want)
if dB:
want = (20 if amplitude else 10) * np.log10(want)

assert want == got, f"expected {want}, got {got}"
4 changes: 3 additions & 1 deletion mne/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2390,14 +2390,16 @@ def _convert_psds(
np.sqrt(psds, out=psds)
psds *= scaling
ylabel = rf"$\mathrm{{{unit}/\sqrt{{Hz}}}}$"
coef = 20
else:
psds *= scaling * scaling
if "/" in unit:
unit = f"({unit})"
ylabel = rf"$\mathrm{{{unit}²/Hz}}$"
coef = 10
if dB:
np.log10(np.maximum(psds, np.finfo(float).tiny), out=psds)
psds *= 10
psds *= coef
ylabel = r"$\mathrm{dB}\ $" + ylabel
ylabel = "Power (" + ylabel if estimate == "power" else "Amplitude (" + ylabel
ylabel += ")"
Expand Down

0 comments on commit 41dbdd5

Please sign in to comment.