diff --git a/oceans/filters.py b/oceans/filters.py
index e053b3c..894e361 100644
--- a/oceans/filters.py
+++ b/oceans/filters.py
@@ -415,10 +415,7 @@ def medfilt1(x, L=3):
     >>> L = 103
     >>> xout = medfilt1(x=x, L=L)
     >>> ax = plt.subplot(212)
-    >>> (
-    ...     l1,
-    ...     l2,
-    ... ) = ax.plot(
+    >>> (l1, l2,) = ax.plot(
     ...     x
     ... ), ax.plot(xout)
     >>> ax.grid(True)
@@ -570,7 +567,7 @@ def md_trenberth(x):
     return y
 
 
-def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
+def pl33tn(x, dt=1.0, T=33.0, mode="valid", t=None):
     """
     Computes low-passed series from `x` using pl33 filter, with optional
     sample interval `dt` (hours) and filter half-amplitude period T (hours)
@@ -608,14 +605,24 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
     """
 
     import cf_xarray  # noqa: F401
+    import pandas as pd
     import xarray as xr
 
-    if isinstance(x, xr.Dataset):
-        raise TypeError("Input a DataArray not a Dataset.")
+    if isinstance(x, (xr.Dataset, pd.DataFrame)):
+        raise TypeError("Input a DataArray not a Dataset, or a Series not a DataFrame.")
 
+    if isinstance(x, pd.Series) and not isinstance(
+        x.index,
+        pd.DatetimeIndex,
+    ):
+        raise TypeError("Input Series needs to have parsed datetime indices.")
+
+    # find dt in units of hours
     if isinstance(x, xr.DataArray):
-        # find dt in units of hours
-        dt = (x.cf["T"][1] - x.cf["T"][0]) * 1e-9 / 3600
+        # Divide by an explicit one-hour timedelta so the unit is unambiguous.
+        dt = (x.cf["T"][1] - x.cf["T"][0]) / np.timedelta64(1, "h")
+    elif isinstance(x, pd.Series):
+        dt = (x.index[1] - x.index[0]) / pd.Timedelta(hours=1)
 
     pl33 = np.array(
         [
@@ -694,18 +701,20 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
     dt = float(dt) * (33.0 / T)
 
     filter_time = np.arange(0.0, 33.0, dt, dtype="d")
-    # N = len(filter_time)
+    Nt = len(filter_time)
     filter_time = np.hstack((-filter_time[-1:0:-1], filter_time))
 
     pl33 = np.interp(filter_time, _dt, pl33)
     pl33 /= pl33.sum()
 
     if isinstance(x, xr.DataArray):
+        x = x.interpolate_na(dim=x.cf["T"].name)
+
         weight = xr.DataArray(pl33, dims=["window"])
         xf = (
             x.rolling({x.cf["T"].name: len(pl33)}, center=True)
             .construct({x.cf["T"].name: "window"})
-            .dot(weight)
+            .dot(weight, dims="window")
         )
         # update attrs
         attrs = {
@@ -715,7 +724,28 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
         }
         xf.attrs = attrs
 
+    elif isinstance(x, pd.Series):
+        xf = x.to_frame().apply(np.convolve, v=pl33, mode=mode)
+
+        # nan out edges which are not good values anyway
+        # (the kernel has 2 * Nt - 1 taps, i.e. Nt - 1 to each side of centre)
+        if mode == "same" and Nt > 1:
+            xf[: Nt - 1] = np.nan
+            xf[-(Nt - 1) :] = np.nan
+
     else:  # use numpy
         xf = np.convolve(x, pl33, mode=mode)
 
+        # times to match xf
+        if t is not None:
+            # drops Nt - 1 samples at each end, matching mode="valid" output
+            tf = t[Nt - 1 : len(t) - (Nt - 1)]
+            return xf, tf
+
+        # nan out edges which are not good values anyway
+        # (the kernel has 2 * Nt - 1 taps, i.e. Nt - 1 to each side of centre)
+        if mode == "same" and Nt > 1:
+            xf[: Nt - 1] = np.nan
+            xf[-(Nt - 1) :] = np.nan
+
     return xf