Skip to content

Commit

Permalink
BUG: Fix bug with minimum phase filters (mne-tools#12507)
Browse files Browse the repository at this point in the history
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel McCloy <[email protected]>
  • Loading branch information
3 people authored and snwnde committed Mar 20, 2024
1 parent aa9e199 commit 4595911
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 98 deletions.
5 changes: 5 additions & 0 deletions doc/changes/devel/12507.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Fix bug where using ``phase="minimum"`` in filtering functions like
:meth:`mne.io.Raw.filter` constructed a filter half the desired length with
compromised attenuation. Now ``phase="minimum"`` has the same length and comparable
suppression as ``phase="zero"``, and the old (incorrect) behavior can be achieved
with ``phase="minimum-half"``, by `Eric Larson`_.
73 changes: 8 additions & 65 deletions mne/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
_setup_cuda_fft_resample,
_smart_pad,
)
from .fixes import minimum_phase
from .parallel import parallel_func
from .utils import (
_check_option,
Expand Down Expand Up @@ -307,39 +308,7 @@ def _overlap_add_filter(
copy=True,
pad="reflect_limited",
):
"""Filter the signal x using h with overlap-add FFTs.
Parameters
----------
x : array, shape (n_signals, n_times)
Signals to filter.
h : 1d array
Filter impulse response (FIR filter coefficients). Must be odd length
if ``phase='linear'``.
n_fft : int
Length of the FFT. If None, the best size is determined automatically.
phase : str
If ``'zero'``, the delay for the filter is compensated (and it must be
an odd-length symmetric filter). If ``'linear'``, the response is
uncompensated. If ``'zero-double'``, the filter is applied in the
forward and reverse directions. If 'minimum', a minimum-phase
filter will be used.
picks : list | None
See calling functions.
n_jobs : int | str
Number of jobs to run in parallel. Can be ``'cuda'`` if ``cupy``
is installed properly.
copy : bool
If True, a copy of x, filtered, is returned. Otherwise, it operates
on x in place.
pad : str
Padding type for ``_smart_pad``.
Returns
-------
x : array, shape (n_signals, n_times)
x filtered.
"""
"""Filter the signal x using h with overlap-add FFTs."""
# set up array for filtering, reshape to 2D, operate on last axis
x, orig_shape, picks = _prep_for_filtering(x, copy, picks)
# Extend the signal by mirroring the edges to reduce transient filter
Expand Down Expand Up @@ -526,34 +495,6 @@ def _construct_fir_filter(
(windowing is a smoothing in frequency domain).
If x is multi-dimensional, this operates along the last dimension.
Parameters
----------
sfreq : float
Sampling rate in Hz.
freq : 1d array
Frequency sampling points in Hz.
gain : 1d array
Filter gain at frequency sampling points.
Must be all 0 and 1 for fir_design=="firwin".
filter_length : int
Length of the filter to use. Must be odd length if phase == "zero".
phase : str
If 'zero', the delay for the filter is compensated (and it must be
an odd-length symmetric filter). If 'linear', the response is
uncompensated. If 'zero-double', the filter is applied in the
forward and reverse directions. If 'minimum', a minimum-phase
filter will be used.
fir_window : str
The window to use in FIR design, can be "hamming" (default),
"hann", or "blackman".
fir_design : str
Can be "firwin2" or "firwin".
Returns
-------
h : array
Filter coefficients.
"""
assert freq[0] == 0
if fir_design == "firwin2":
Expand All @@ -562,7 +503,7 @@ def _construct_fir_filter(
assert fir_design == "firwin"
fir_design = partial(_firwin_design, sfreq=sfreq)
# issue a warning if attenuation is less than this
min_att_db = 12 if phase == "minimum" else 20
min_att_db = 12 if phase == "minimum-half" else 20

# normalize frequencies
freq = np.array(freq) / (sfreq / 2.0)
Expand All @@ -575,11 +516,13 @@ def _construct_fir_filter(
# Use overlap-add filter with a fixed length
N = _check_zero_phase_length(filter_length, phase, gain[-1])
# construct symmetric (linear phase) filter
if phase == "minimum":
if phase == "minimum-half":
h = fir_design(N * 2 - 1, freq, gain, window=fir_window)
h = signal.minimum_phase(h)
h = minimum_phase(h)
else:
h = fir_design(N, freq, gain, window=fir_window)
if phase == "minimum":
h = minimum_phase(h, half=False)
assert h.size == N
att_db, att_freq = _filter_attenuation(h, freq, gain)
if phase == "zero-double":
Expand Down Expand Up @@ -2162,7 +2105,7 @@ def detrend(x, order=1, axis=-1):
"blackman": dict(name="Blackman", ripple=0.0017, attenuation=74),
}
_known_fir_windows = tuple(sorted(_fir_window_dict.keys()))
_known_phases_fir = ("linear", "zero", "zero-double", "minimum")
_known_phases_fir = ("linear", "zero", "zero-double", "minimum", "minimum-half")
_known_phases_iir = ("zero", "zero-double", "forward")
_known_fir_designs = ("firwin", "firwin2")
_fir_design_dict = {
Expand Down
55 changes: 55 additions & 0 deletions mne/fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,3 +889,58 @@ def _numpy_h5py_dep():
"ignore", "`product` is deprecated.*", DeprecationWarning
)
yield


def minimum_phase(h, method="homomorphic", n_fft=None, *, half=True):
"""Wrap scipy.signal.minimum_phase with half option."""
# Can be removed once
from scipy.fft import fft, ifft
from scipy.signal import minimum_phase as sp_minimum_phase

assert isinstance(method, str) and method == "homomorphic"

if "half" in inspect.getfullargspec(sp_minimum_phase).kwonlyargs:
return sp_minimum_phase(h, method=method, n_fft=n_fft, half=half)
h = np.asarray(h)
if np.iscomplexobj(h):
raise ValueError("Complex filters not supported")
if h.ndim != 1 or h.size <= 2:
raise ValueError("h must be 1-D and at least 2 samples long")
n_half = len(h) // 2
if not np.allclose(h[-n_half:][::-1], h[:n_half]):
warnings.warn(
"h does not appear to by symmetric, conversion may fail",
RuntimeWarning,
stacklevel=2,
)
if n_fft is None:
n_fft = 2 ** int(np.ceil(np.log2(2 * (len(h) - 1) / 0.01)))
n_fft = int(n_fft)
if n_fft < len(h):
raise ValueError("n_fft must be at least len(h)==%s" % len(h))

# zero-pad; calculate the DFT
h_temp = np.abs(fft(h, n_fft))
# take 0.25*log(|H|**2) = 0.5*log(|H|)
h_temp += 1e-7 * h_temp[h_temp > 0].min() # don't let log blow up
np.log(h_temp, out=h_temp)
if half: # halving of magnitude spectrum optional
h_temp *= 0.5
# IDFT
h_temp = ifft(h_temp).real
# multiply pointwise by the homomorphic filter
# lmin[n] = 2u[n] - d[n]
# i.e., double the positive frequencies and zero out the negative ones;
# Oppenheim+Shafer 3rd ed p991 eq13.42b and p1004 fig13.7
win = np.zeros(n_fft)
win[0] = 1
stop = n_fft // 2
win[1:stop] = 2
if n_fft % 2:
win[stop] = 1
h_temp *= win
h_temp = ifft(np.exp(fft(h_temp)))
h_minimum = h_temp.real

n_out = (n_half + len(h) % 2) if half else len(h)
return h_minimum[:n_out]
62 changes: 55 additions & 7 deletions mne/tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,25 +606,31 @@ def test_filters():
# try new default and old default
freqs = fftfreq(a.shape[-1], 1.0 / sfreq)
A = np.abs(fft(a))
kwargs = dict(fir_design="firwin")
kw = dict(fir_design="firwin")
for fl in ["auto", "10s", "5000ms", 1024, 1023]:
bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, **kwargs)
bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, **kwargs)
lp = filter_data(a, sfreq, None, 8, None, fl, 10, 1.0, n_jobs=2, **kwargs)
hp = filter_data(lp, sfreq, 4, None, None, fl, 1.0, 10, **kwargs)
bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, **kw)
bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, **kw)
lp = filter_data(a, sfreq, None, 8, None, fl, 10, 1.0, n_jobs=2, **kw)
hp = filter_data(lp, sfreq, 4, None, None, fl, 1.0, 10, **kw)
assert_allclose(hp, bp, rtol=1e-3, atol=2e-3)
assert_allclose(bp + bs, a, rtol=1e-3, atol=1e-3)
# Sanity check ttenuation
mask = (freqs > 5.5) & (freqs < 6.5)
assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]), 1.0, atol=0.02)
assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]), 0.0, atol=0.2)
# now the minimum-phase versions
bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, phase="minimum", **kwargs)
bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, phase="minimum-half", **kw)
bs = filter_data(
a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, phase="minimum", **kwargs
a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, phase="minimum-half", **kw
)
assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]), 1.0, atol=0.11)
assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]), 0.0, atol=0.3)
bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, phase="minimum", **kw)
bs = filter_data(
a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0, phase="minimum", **kw
)
assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]), 1.0, atol=0.12)
assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]), 0.0, atol=0.27)

# and since these are low-passed, downsampling/upsampling should be close
n_resamp_ignore = 10
Expand Down Expand Up @@ -1050,3 +1056,45 @@ def test_filter_picks():
raw.filter(picks=picks, **kwargs)
want = want[1:]
assert_allclose(raw.get_data(), want)


def test_filter_minimum_phase_bug():
"""Test gh-12267 is fixed."""
sfreq = 1000.0
n_taps = 1001
l_freq = 10.0 # Hz
kwargs = dict(
data=None,
sfreq=sfreq,
l_freq=l_freq,
h_freq=None,
filter_length=n_taps,
l_trans_bandwidth=l_freq / 2.0,
)
h = create_filter(phase="zero", **kwargs)
h_min = create_filter(phase="minimum", **kwargs)
h_min_half = create_filter(phase="minimum-half", **kwargs)
assert h_min.size == h.size
kwargs = dict(worN=10000, fs=sfreq)
w, H = freqz(h, **kwargs)
assert w[0] == 0
dc_dB = 20 * np.log10(np.abs(H[0]))
assert dc_dB < -100
# good
w_min, H_min = freqz(h_min, **kwargs)
assert_allclose(w, w_min)
dc_dB_min = 20 * np.log10(np.abs(H_min[0]))
assert dc_dB_min < -100
mask = w < 5
assert 10 < mask.sum() < 101
assert_allclose(np.abs(H[mask]), np.abs(H_min[mask]), atol=1e-3, rtol=1e-3)
assert_array_less(20 * np.log10(np.abs(H[mask])), -40)
assert_array_less(20 * np.log10(np.abs(H_min[mask])), -40)
# bad
w_min_half, H_min_half = freqz(h_min_half, **kwargs)
assert_allclose(w, w_min_half)
dc_dB_min_half = 20 * np.log10(np.abs(H_min_half[0]))
assert -80 < dc_dB_min_half < 40
dB_min_half = 20 * np.log10(np.abs(H_min_half[mask]))
assert_array_less(dB_min_half, -20)
assert not (dB_min_half < -30).all()
39 changes: 27 additions & 12 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2809,21 +2809,36 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
docdict["phase"] = """
phase : str
Phase of the filter.
When ``method='fir'``, symmetric linear-phase FIR filters are constructed,
and if ``phase='zero'`` (default), the delay of this filter is compensated
for, making it non-causal. If ``phase='zero-double'``,
then this filter is applied twice, once forward, and once backward
(also making it non-causal). If ``'minimum'``, then a minimum-phase filter
will be constructed and applied, which is causal but has weaker stop-band
suppression.
When ``method='iir'``, ``phase='zero'`` (default) or
``phase='zero-double'`` constructs and applies IIR filter twice, once
forward, and once backward (making it non-causal) using
:func:`~scipy.signal.filtfilt`.
If ``phase='forward'``, it constructs and applies forward IIR filter using
When ``method='fir'``, symmetric linear-phase FIR filters are constructed
with the following behaviors when ``method="fir"``:
``"zero"`` (default)
The delay of this filter is compensated for, making it non-causal.
``"minimum"``
A minimum-phase filter will be constructed by decomposing the zero-phase filter
into a minimum-phase and all-pass systems, and then retaining only the
minimum-phase system (of the same length as the original zero-phase filter)
via :func:`scipy.signal.minimum_phase`.
``"zero-double"``
*This is a legacy option for compatibility with MNE <= 0.13.*
The filter is applied twice, once forward, and once backward
(also making it non-causal).
``"minimum-half"``
*This is a legacy option for compatibility with MNE <= 1.6.*
A minimum-phase filter will be reconstructed from the zero-phase filter with
half the length of the original filter.
When ``method='iir'``, ``phase='zero'`` (default) or equivalently ``'zero-double'``
constructs and applies IIR filter twice, once forward, and once backward (making it
non-causal) using :func:`~scipy.signal.filtfilt`; ``phase='forward'`` will apply
the filter once in the forward (causal) direction using
:func:`~scipy.signal.lfilter`.
.. versionadded:: 0.13
.. versionchanged:: 1.7
The behavior for ``phase="minimum"`` was fixed to use a filter of the requested
length and improved suppression.
"""

docdict["physical_range_export_params"] = """
Expand Down
Loading

0 comments on commit 4595911

Please sign in to comment.