Skip to content

Commit

Permalink
Allow dask (#181)
Browse files Browse the repository at this point in the history
* added dask='allowed', and solved floating point difference in array_equal test_correlation and test_correlation_with_dim

* added dask='allowed', and solved floating point difference in array_equal test_correlation and test_correlation_with_dim

* Added allow_rechunk is True

* allow rechunk is False

* perhaps should move towards da.polyfit() to support dask.arrays

* Fix linear trend so that tests do not fail on chunked data

* Fix formatting

* Fix remaining typing errors

* Remove commented out code and more formatting

* Update changelog and citation information

* Remove dask as dependency

* Add dask back in

* Fix citation and changelog.

---------

Co-authored-by: semvijverberg <[email protected]>
  • Loading branch information
ClaireDons and semvijverberg authored Sep 27, 2024
1 parent caa43e3 commit 3a67444
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 16 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/).

## [Unreleased]

### Added
- Dask support ([#181](https://github.com/AI4S2S/s2spy/pull/181))

### Changed

## 0.4.0 (2023-09-13)

### Added
Expand Down
6 changes: 6 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ authors:
given-names: Jannes
affiliation: "Vrije Universiteit Amsterdam"

-
affiliation: "Netherlands eScience Center"
family-names: Donnelly
given-names: Claire
orcid: https://orcid.org/0000-0002-2546-4528

date-released: 2022-09-02
version: "0.4.0"
repository-code: "https://github.com/AI4S2S/s2spy"
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ license = "Apache-2.0"
requires-python = ">3.8,<3.12"
authors = [
{email = "[email protected]"},
{name = "Yang Liu, Bart Schilperoort, Peter Kalverla, Jannes van Ingen, Sem Vijverberg"}
{name = "Yang Liu, Bart Schilperoort, Peter Kalverla, Jannes van Ingen, Sem Vijverberg, Claire Donnelly"}
]
maintainers = [
{name = "Yang Liu", email = "[email protected]"},
{name = "Bart Schilperoort", email = "[email protected]"},
{name = "Peter Kalverla", email = "[email protected]"},
{name = "Jannes van Ingen", email = "[email protected]"},
{name = "Sem Vijverberg", email = "[email protected]"},
{name = "Claire Donnelly", email = "[email protected]"},
]
keywords = [
"AI",
Expand All @@ -35,6 +36,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
]
dependencies = [
"dask",
"lilio",
"matplotlib",
"netcdf4",
Expand Down
1 change: 1 addition & 0 deletions s2spy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This package is a high-level python package integrating expert knowledge
and artificial intelligence to boost (sub) seasonal forecasting.
"""

import logging
from .rgdr.rgdr import RGDR

Expand Down
9 changes: 7 additions & 2 deletions s2spy/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Preprocessor for s2spy workflow."""

import warnings
from typing import Literal
from typing import Union
import numpy as np
import scipy.stats
import scipy
import xarray as xr


Expand Down Expand Up @@ -39,7 +40,9 @@ def _trend_linear(data: Union[xr.DataArray, xr.Dataset]) -> dict:
input_core_dims=[["time"], ["time"]],
output_core_dims=[[], []],
vectorize=True,
dask="parallelized",
)

return {"slope": slope, "intercept": intercept}


Expand Down Expand Up @@ -142,7 +145,9 @@ def _check_data_resolution_match(
"daily": np.timedelta64(1, "D"),
}
time_intervals = np.diff(data["time"].to_numpy())
temporal_resolution = np.median(time_intervals).astype("timedelta64[D]")
temporal_resolution: np.timedelta64 = np.median(time_intervals).astype(
"timedelta64[D]"
)
if timescale == "monthly":
temporal_resolution = temporal_resolution.astype(int)
min_days, max_days = (28, 31)
Expand Down
1 change: 1 addition & 0 deletions s2spy/rgdr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Response Guided Dimensionality Reduction."""

from . import label_alignment # noqa: F401 (unused import)
3 changes: 2 additions & 1 deletion s2spy/rgdr/label_alignment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Label alignment tools for RGDR clusters."""

import itertools
import string
from copy import copy
Expand Down Expand Up @@ -102,7 +103,7 @@ def _calculate_overlap(
mask_a = xr.where(cluster_labels.sel(split=split_a) == cluster_a, 1, 0).values
mask_b = xr.where(cluster_labels.sel(split=split_b) == cluster_b, 1, 0).values

return np.sum(np.logical_and(mask_a, mask_b)) / np.sum(mask_a)
return np.sum(np.logical_and(mask_a, mask_b)) / np.sum(mask_a) # type: ignore


def calculate_overlap_table(cluster_labels: xr.DataArray) -> pd.DataFrame:
Expand Down
14 changes: 8 additions & 6 deletions s2spy/rgdr/rgdr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Response Guided Dimensionality Reduction."""

import warnings
from os import linesep
from typing import Optional
Expand Down Expand Up @@ -281,6 +282,7 @@ def correlation(
field,
target,
input_core_dims=[[corr_dim], [corr_dim]],
dask="allowed",
vectorize=True,
output_core_dims=[[], []],
)
Expand Down Expand Up @@ -606,12 +608,12 @@ def transform(self, data: xr.DataArray) -> xr.DataArray:
# Add the geographical centers for later alignment between, e.g., splits
reduced_data = utils.geographical_cluster_center(data, reduced_data)
# Include explanations about geographical centers as attributes
reduced_data.attrs[
"data"
] = "Clustered data with Response Guided Dimensionality Reduction."
reduced_data.attrs[
"coordinates"
] = "Latitudes and longitudes are geographical centers associated with clusters."
reduced_data.attrs["data"] = (
"Clustered data with Response Guided Dimensionality Reduction."
)
reduced_data.attrs["coordinates"] = (
"Latitudes and longitudes are geographical centers associated with clusters."
)

# Remove the '0' cluster
reduced_data = reduced_data.where(reduced_data["cluster_labels"] != 0).dropna(
Expand Down
1 change: 1 addition & 0 deletions s2spy/rgdr/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Commonly used utility functions for s2spy."""

from typing import TypeVar
import numpy as np
import xarray as xr
Expand Down
8 changes: 6 additions & 2 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for the s2spy.preprocess module."""

import numpy as np
import pytest
import scipy.signal
Expand All @@ -21,13 +22,15 @@ class TestPreprocessMethods:
# Define inputs as fixtures
@pytest.fixture
def raw_field(self):
return xr.open_dataset(
f"{TEST_FILE_PATH}/sst_daily_1979-2018_5deg_Pacific_175_240E_25_50N.nc"
data = xr.open_dataset(
f"{TEST_FILE_PATH}/sst_daily_1979-2018_5deg_Pacific_175_240E_25_50N.nc",
chunks="auto",
).sel(
time=slice("2010-01-01", "2011-12-31"),
latitude=slice(40, 30),
longitude=slice(180, 190),
)
return data

def test_check_input_data_incorrect_type(self):
dummy_data = np.ones((3, 3))
Expand All @@ -49,6 +52,7 @@ def test_get_and_subtract_linear_trend(self, raw_field):
raw_field,
input_core_dims=[["time"]],
output_core_dims=[["time"]],
dask="parallelized",
).transpose("time", ...)
trend = preprocess._get_trend(raw_field, "linear")
result = preprocess._subtract_trend(raw_field, "linear", trend)
Expand Down
9 changes: 5 additions & 4 deletions tests/test_rgdr/test_rgdr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tests for the s2s.rgdr module."""

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -141,16 +142,16 @@ def test_pearsonr_nan(self):
def test_correlation(self, dummy_dataarray, dummy_timeseries):
c_val, p_val = rgdr.correlation(dummy_dataarray, dummy_timeseries)

np.testing.assert_equal(c_val.values, 1)
np.testing.assert_equal(p_val.values, 0)
np.testing.assert_almost_equal(c_val.values, 1, decimal=5)
np.testing.assert_almost_equal(p_val.values, 0, decimal=5)

def test_correlation_dim_name(self, dummy_dataarray, dummy_timeseries):
da = dummy_dataarray.rename({"time": "i_interval"})
ts = dummy_timeseries.rename({"time": "i_interval"})
c_val, p_val = rgdr.correlation(da, ts, corr_dim="i_interval")

np.testing.assert_equal(c_val.values, 1)
np.testing.assert_equal(p_val.values, 0)
np.testing.assert_almost_equal(c_val.values, 1, decimal=5)
np.testing.assert_almost_equal(p_val.values, 0, decimal=5)

def test_correlation_wrong_target_dim_name(self, dummy_dataarray, dummy_timeseries):
ts = dummy_timeseries.rename({"time": "dummy"})
Expand Down

0 comments on commit 3a67444

Please sign in to comment.