Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
martinvonk committed Mar 13, 2024
1 parent a443db2 commit ec46452
Show file tree
Hide file tree
Showing 6 changed files with 367 additions and 352 deletions.
687 changes: 347 additions & 340 deletions doc/examples/example02_distributions.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion doc/examples/example04_package_comparison.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
"import pastas as ps\n",
"import pandas as pd\n",
"import scipy.stats as scs\n",
"import matplotlib.pyplot as plt"
"import matplotlib.pyplot as plt\n",
"\n",
"print(si.show_versions())"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion src/spei/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# flake8: noqa
from . import climdex, dist, plot, si, utils
from ._version import __version__, show_versions
from .si import sgi, spei, spi, ssfi, SI
from .si import SI, sgi, spei, spi, ssfi
4 changes: 3 additions & 1 deletion src/spei/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def __post_init__(self):
self.p0 = (data_fit == 0.0).sum() / len(data_fit)

@staticmethod
def fit_dist(data: Series, dist: ContinuousDist) -> Tuple[Optional[List[float]], float, float]:
def fit_dist(
data: Series, dist: ContinuousDist
) -> Tuple[Optional[List[float]], float, float]:
"""
Fits a Scipy continuous distribution to the data.
Expand Down
14 changes: 9 additions & 5 deletions src/spei/si.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,9 @@ def fit_distribution(self):
dfval_window = daily_window_group_yearly_df(
dfval=self._grouped_year, period=period
)
for dfval_rwindow in dfval_window.rolling(window=window, min_periods=window, closed="right"):
for dfval_rwindow in dfval_window.rolling(
window=window, min_periods=window, closed="right"
):
if len(dfval_rwindow) < window:
continue # min_periods ignored by Rolling.__iter__
date = dfval_rwindow.index[period]
Expand All @@ -327,7 +329,9 @@ def fit_distribution(self):
self._dist_dict[date] = fd
else:
logging.info("Using groupby fit by frequency method")
for date, grval in self._grouped_year.groupby(Grouper(freq=self.fit_freq)):
for date, grval in self._grouped_year.groupby(
Grouper(freq=str(self.fit_freq))
):
data = get_data_series(grval)
fd = Dist(
data=data,
Expand Down Expand Up @@ -370,7 +374,7 @@ def cdf_nsf(self) -> Series:
"""
logging.info("Using the normal scores transform")
cdf = Series(nan, index=self.series.index, dtype=float)
for _, grval in self._grouped_year.groupby(Grouper(freq=self.fit_freq)):
for _, grval in self._grouped_year.groupby(Grouper(freq=str(self.fit_freq))):
data = get_data_series(grval).sort_values()
n = len(data)
cdf.loc[data.index] = linspace(1 / (2 * n), 1 - 1 / (2 * n), n)
Expand All @@ -397,5 +401,5 @@ def get_dist(self, date: Timestamp) -> Dist:
dist = self._dist_dict[k]
if date in dist.data.index:
return dist
else:
raise KeyError("Date not found in distributions")

raise KeyError("Date not found in distributions")
8 changes: 4 additions & 4 deletions src/spei/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ def daily_window_group_yearly_df(dfval: DataFrame, period: int) -> DataFrame:
"""
dfval_window_index_start = [
dfval.index[0] + Timedelta(value=-i, unit="D")
for i in reversed(range(1, period+1))
for i in reversed(range(1, period + 1))
]
dfval_window_index_end = [
dfval.index[-1] + Timedelta(value=i, unit="D") for i in range(1, period+1)
dfval.index[-1] + Timedelta(value=i, unit="D") for i in range(1, period + 1)
]
dfval_window_index = DatetimeIndex(
dfval_window_index_start + dfval.index.to_list() + dfval_window_index_end
Expand All @@ -132,6 +132,6 @@ def daily_window_group_yearly_df(dfval: DataFrame, period: int) -> DataFrame:
nan, index=dfval_window_index, columns=dfval.columns, dtype=float
)
dfval_window.loc[dfval.index, dfval.columns] = dfval.values
dfval_window.iloc[: period] = dfval.iloc[-period :].values
dfval_window.iloc[-period :] = dfval.iloc[: period].values
dfval_window.iloc[:period] = dfval.iloc[-period:].values
dfval_window.iloc[-period:] = dfval.iloc[:period].values
return dfval_window

0 comments on commit ec46452

Please sign in to comment.