Skip to content

Commit

Permalink
Functionality to batch remove trends from dataset dictionary (#155)
Browse files Browse the repository at this point in the history
* added match and detrend

* Added whats new entry [skip-ci]
  • Loading branch information
jbusecke authored Jul 8, 2021
1 parent 7cfc57f commit 9fad5d6
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,4 @@ target/
.vscode
.mypy_cache
readthedocs.yml
cmip6_preprocessing/_version.py
53 changes: 53 additions & 0 deletions cmip6_preprocessing/drift_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from xarrayutils.utils import linear_trend

from cmip6_preprocessing.postprocessing import _match_datasets, exact_attrs
from cmip6_preprocessing.utils import cmip6_dataset_id


Expand Down Expand Up @@ -369,3 +370,55 @@ def remove_trend(ds, ds_slope, variable, ref_date, check_mask=True):
] = f"linear_trend_{cmip6_dataset_id(ds_slope)}_{trend_start}_{trend_stop}"

return detrended


def match_and_remove_trend(
    ddict, trend_ddict, ref_date="1850", nomatch="warn", **detrend_kwargs
):
    """Find and remove matching trend datasets from a dictionary of datasets.

    Parameters
    ----------
    ddict : dict
        Dictionary of xr.Datasets which should get a trend/drift removed.
    trend_ddict : dict
        Dictionary with results of linear regressions. These are removed
        from the matching datasets in `ddict`.
    ref_date : str, optional
        Start date of the trend, by default "1850".
    nomatch : str, optional
        Behavior when, for a given dataset in `ddict`, there is no matching
        trend dataset in `trend_ddict`. Can be "warn", "raise", or "ignore",
        by default "warn".
    **detrend_kwargs
        Additional keyword arguments passed through to `remove_trend`.

    Returns
    -------
    dict
        Dictionary of detrended datasets. Only contains values of `ddict`
        that actually had a trend removed.
    """
    ddict_detrended = {}
    # Match on all exact attributes except the experiment (so e.g. a trend
    # estimated in one experiment can be removed from another run of the same
    # model/member), but additionally require the same variable.
    match_attrs = [ma for ma in exact_attrs if ma != "experiment_id"] + [
        "variable_id"
    ]

    for key, ds in ddict.items():
        matched = _match_datasets(
            ds, trend_ddict, match_attrs, pop=False, unique=True, nomatch=nomatch
        )
        # `_match_datasets` returns the input dataset itself as the first list
        # element, so a length of 2 means exactly one trend dataset was found.
        # Anything shorter means no match (already handled via `nomatch`).
        if len(matched) == 2:
            trend_ds = matched[1]
            variable = ds.attrs["variable_id"]
            ds_detrended = ds.assign(
                {
                    variable: remove_trend(
                        ds, trend_ds, variable, ref_date=ref_date, **detrend_kwargs
                    )
                }
            )
            ddict_detrended[key] = ds_detrended

    return ddict_detrended
20 changes: 19 additions & 1 deletion cmip6_preprocessing/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _match_attrs(ds_a, ds_b, match_attrs):
return sum([ds_a.attrs[i] == ds_b.attrs[i] for i in match_attrs])


def _match_datasets(ds, ds_dict, match_attrs, pop=True):
def _match_datasets(ds, ds_dict, match_attrs, pop=True, nomatch="ignore", unique=False):
"""Find all datasets in a dictionary of datasets that have a matching set
of attributes and return a list of datasets for merging/concatting.
Optionally remove the matching datasets from the input dict.
Expand All @@ -41,6 +41,24 @@ def _match_datasets(ds, ds_dict, match_attrs, pop=True):
# preserve the original dictionary key of the chosen dataset in attribute.
ds_matched.attrs["original_key"] = k
datasets.append(ds_matched)
if len(datasets) > 2:
if unique:
raise ValueError(
f"Found more than one matching dataset for {cmip6_dataset_id(ds)}. Pass `unique=False` to ignore this."
)
nomatch_msg = f"Could not find a matching dataset for {cmip6_dataset_id(ds)}"
if len(datasets) < 2:
if nomatch == "ignore":
pass
elif nomatch == "warn":
warnings.warn(nomatch_msg)
elif nomatch == "raise":
raise RuntimeError(nomatch_msg)
else:
# Could this be done in an annotation? Or some other built in way to do this?
raise ValueError(
f"Invalid input ({nomatch}) for `nomatch`, should be `ignore`, `warn`, or `raise`"
)
return datasets


Expand Down
13 changes: 12 additions & 1 deletion docs/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,20 @@

What's New
===========
.. _whats-new.0.5.0:

v0.5.0 (unreleased)
-------------------

New Features
~~~~~~~~~~~~

- :py:meth:`~cmip6_preprocessing.drift_removal.match_and_remove_trend` enables batch detrending/drift removal
  from a dictionary of datasets (:pull:`155`). By `Julius Busecke <https://github.com/jbusecke>`_

.. _whats-new.0.4.0:

v0.4.0 (unreleased)
v0.4.0 (2021/6/9)
-------------------

New Features
Expand Down
134 changes: 134 additions & 0 deletions tests/test_drift_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
_construct_cfdate,
calculate_drift,
find_date_idx,
match_and_remove_trend,
remove_trend,
replace_time,
unify_time,
)
from cmip6_preprocessing.postprocessing import exact_attrs


# I copied this from a PR I made to parcels a while back.
Expand Down Expand Up @@ -527,3 +529,135 @@ def test_calculate_drift_exceptions_partial():
with pytest.warns(UserWarning) as winfo:
reg = calculate_drift(ds_control, ds, "test", compute_short_trends=True)
assert "years to calculate trend. Using 1 years only" in winfo[0].message.args[0]


def _make_random_ds(variant_label, variable_id, experiment_id, time, nx=10, ny=20):
    """Build a random (nx, ny, time) test dataset with the given CMIP6-style attrs."""
    ds = xr.DataArray(
        np.random.rand(nx, ny, len(time)),
        dims=["x", "y", "time"],
        coords={"time": time},
    ).to_dataset(name=variable_id)
    # Fill every matching attribute with a dummy value, then override the
    # attributes that distinguish the datasets in the test.
    ds.attrs = {k: "dummy" for k in exact_attrs + ["variable_id"]}
    ds.attrs["variant_label"] = variant_label
    ds.attrs["variable_id"] = variable_id
    ds.attrs["experiment_id"] = experiment_id
    return ds


def _make_trend_ds(ds):
    """Build a constant (random-valued) trend dataset shaped like `ds` without time."""
    return xr.ones_like(ds.isel(time=0)).drop_vars("time") * np.random.rand()


@pytest.mark.parametrize("ref_date", ["1850", "2000-01-02"])
def test_match_and_remove_trend_matching_experiment(ref_date):
    # Construct a dict of datasets to be detrended. The trends should match
    # across experiments (experiment_id is excluded from the matching attrs),
    # but must respect variant_label and variable_id.
    nt = 24
    time_historical = xr.cftime_range("1850-01-01", periods=nt, freq="1MS")
    time_ssp = xr.cftime_range("2014-01-01", periods=nt, freq="1MS")

    ds_a_hist_vara = _make_random_ds("a", "vara", "historical", time_historical)
    ds_a_hist_varb = _make_random_ds("a", "varb", "historical", time_historical)
    ds_a_other_vara = _make_random_ds("a", "vara", "other", time_ssp)
    ds_b_hist_vara = _make_random_ds("b", "vara", "historical", time_historical)

    ddict = {
        "ds_a_hist_vara": ds_a_hist_vara,
        "ds_a_hist_varb": ds_a_hist_varb,
        "ds_a_other_vara": ds_a_other_vara,
        "ds_b_hist_vara": ds_b_hist_vara,
    }

    trend_a_vara = _make_trend_ds(ds_a_hist_vara)
    trend_a_varb = _make_trend_ds(ds_a_hist_varb)
    trend_b_vara = _make_trend_ds(ds_b_hist_vara)

    # Keys of the trend dict are irrelevant for matching; use integers.
    ddict_trend = {
        n: ds for n, ds in enumerate([trend_a_vara, trend_b_vara, trend_a_varb])
    }

    ddict_detrended = match_and_remove_trend(ddict, ddict_trend, ref_date=ref_date)

    # Every input dataset should have been matched with the trend of the same
    # variant_label/variable_id, regardless of its experiment_id.
    for name, ds, trend_ds in [
        ("ds_a_hist_vara", ds_a_hist_vara, trend_a_vara),
        ("ds_b_hist_vara", ds_b_hist_vara, trend_b_vara),
        ("ds_a_hist_varb", ds_a_hist_varb, trend_a_varb),
        ("ds_a_other_vara", ds_a_other_vara, trend_a_vara),
    ]:
        variable = ds.attrs["variable_id"]
        expected = remove_trend(ds, trend_ds, variable, ref_date).to_dataset(
            name=variable
        )
        xr.testing.assert_allclose(
            ddict_detrended[name],
            expected,
        )


def test_match_and_remove_trend_nomatch():
    # Create a dataset and a trend dataset whose attributes do not match
    # (according to the matching conventions in `match_and_remove_trend`)
    # and check each `nomatch` mode.
    ds = xr.DataArray().to_dataset(name="test")
    ds.attrs = {k: "a" for k in exact_attrs + ["variable_id"]}
    ds_nomatch = xr.DataArray().to_dataset(name="test")
    ds_nomatch.attrs = {k: "b" for k in exact_attrs + ["variable_id"]}

    # "ignore" silently skips unmatched datasets, so nothing is returned.
    detrended = match_and_remove_trend({"aa": ds}, {"bb": ds_nomatch}, nomatch="ignore")
    assert detrended == {}

    # `match` is interpreted as a regex searched in the message; use a plain
    # substring (the previous trailing `*` was an accidental regex quantifier).
    match_msg = "Could not find a matching dataset for"
    with pytest.warns(UserWarning, match=match_msg):
        match_and_remove_trend({"aa": ds}, {"bb": ds_nomatch}, nomatch="warn")

    with pytest.raises(RuntimeError, match=match_msg):
        match_and_remove_trend({"aa": ds}, {"bb": ds_nomatch}, nomatch="raise")


def test_match_and_remove_trend_nonunique():
    # Create two trend datasets that BOTH match the input dataset
    # (according to the matching conventions in `match_and_remove_trend`);
    # the unique-match check should raise.
    ds = xr.DataArray().to_dataset(name="test")
    ds.attrs = {k: "a" for k in exact_attrs + ["variable_id"]}
    ds_match_a = xr.DataArray().to_dataset(name="test")
    ds_match_b = xr.DataArray().to_dataset(name="test")
    ds_match_a.attrs = ds.attrs
    ds_match_b.attrs = ds.attrs

    # Plain substring for the regex `match` (avoid a stray `*` quantifier).
    match_msg = "Found more than one matching dataset for"
    with pytest.raises(ValueError, match=match_msg):
        match_and_remove_trend({"aa": ds}, {"bb": ds_match_a, "cc": ds_match_b})

0 comments on commit 9fad5d6

Please sign in to comment.