From 9fad5d6d033ae7113c2755dd3d1679e87f9eb4fb Mon Sep 17 00:00:00 2001
From: Julius Busecke 
Date: Thu, 8 Jul 2021 12:39:10 -0400
Subject: [PATCH] Functionality to batch remove trends from dataset dictionary (#155)

* added match and detrend

* Added whats new entry [skip-ci]
---
 .gitignore                            |   1 +
 cmip6_preprocessing/drift_removal.py  |  53 ++++++++++
 cmip6_preprocessing/postprocessing.py |  20 +++-
 docs/whats-new.rst                    |  13 ++-
 tests/test_drift_removal.py           | 134 ++++++++++++++++++++++++++
 5 files changed, 219 insertions(+), 2 deletions(-)

diff --git a/.gitignore b/.gitignore
index a1137391..687856bc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -69,3 +69,4 @@ target/
 .vscode
 .mypy_cache
 readthedocs.yml
+cmip6_preprocessing/_version.py
diff --git a/cmip6_preprocessing/drift_removal.py b/cmip6_preprocessing/drift_removal.py
index b1cd0f28..6c74738b 100644
--- a/cmip6_preprocessing/drift_removal.py
+++ b/cmip6_preprocessing/drift_removal.py
@@ -7,6 +7,7 @@
 from xarrayutils.utils import linear_trend
 
+from cmip6_preprocessing.postprocessing import _match_datasets, exact_attrs
 from cmip6_preprocessing.utils import cmip6_dataset_id
@@ -369,3 +370,55 @@ def remove_trend(ds, ds_slope, variable, ref_date, check_mask=True):
     ] = f"linear_trend_{cmip6_dataset_id(ds_slope)}_{trend_start}_{trend_stop}"
 
     return detrended
+
+
+def match_and_remove_trend(
+    ddict, trend_ddict, ref_date="1850", nomatch="warn", **detrend_kwargs
+):
+    """Find and remove trend files from a dictionary of datasets
+
+    Parameters
+    ----------
+    ddict : dict
+        dictionary with xr.Datasets which should get a trend/drift removed
+    trend_ddict : dict
+        dictionary with results of linear regressions. These should be removed from the datasets in `ddict`
+    ref_date : str, optional
+        Start date of the trend, by default "1850"
+    nomatch : str, optional
+        Defines the behavior when, for a given dataset in `ddict`, there is no matching trend dataset in `trend_ddict`.
+        Can be `warn`, `raise`, or `ignore`, by default 'warn'
+
+    Returns
+    -------
+    dict
+        Dictionary of detrended datasets. Only contains values of `ddict` that actually had a trend removed.
+
+    """
+    ddict_detrended = {}
+    match_attrs = [ma for ma in exact_attrs if ma not in ["experiment_id"]] + [
+        "variable_id"
+    ]
+
+    for k, ds in ddict.items():
+        trend_ds = _match_datasets(
+            ds, trend_ddict, match_attrs, pop=False, unique=True, nomatch=nomatch
+        )
+        # # print(trend_ds)
+        if len(trend_ds) == 2:
+            trend_ds = trend_ds[
+                1
+            ]  # this is a bit clunky. `_match_datasets` returns the input ds first, so we have to grab the second one.
+            # I guess I could pass *trend_ds, but that is not very readable
+            variable = ds.attrs["variable_id"]
+            da_detrended = ds.assign(
+                {
+                    variable: remove_trend(
+                        ds, trend_ds, variable, ref_date=ref_date, **detrend_kwargs
+                    )
+                }
+            )
+            # should this just return a dataset instead?
+            ddict_detrended[k] = da_detrended
+
+    return ddict_detrended
diff --git a/cmip6_preprocessing/postprocessing.py b/cmip6_preprocessing/postprocessing.py
index 17bb6e1c..997f768c 100644
--- a/cmip6_preprocessing/postprocessing.py
+++ b/cmip6_preprocessing/postprocessing.py
@@ -25,7 +25,7 @@ def _match_attrs(ds_a, ds_b, match_attrs):
     return sum([ds_a.attrs[i] == ds_b.attrs[i] for i in match_attrs])
 
 
-def _match_datasets(ds, ds_dict, match_attrs, pop=True):
+def _match_datasets(ds, ds_dict, match_attrs, pop=True, nomatch="ignore", unique=False):
     """Find all datasets in a dictionary of datasets that have a matching set of attributes
     and return a list of datasets for merging/concatting.
     Optionally remove the matching datasets from the input dict.
@@ -41,6 +41,24 @@ def _match_datasets(ds, ds_dict, match_attrs, pop=True):
             # preserve the original dictionary key of the chosen dataset in attribute.
             ds_matched.attrs["original_key"] = k
             datasets.append(ds_matched)
+    if len(datasets) > 2:
+        if unique:
+            raise ValueError(
+                f"Found more than one matching dataset for {cmip6_dataset_id(ds)}. Pass `unique=False` to ignore this."
+            )
+    nomatch_msg = f"Could not find a matching dataset for {cmip6_dataset_id(ds)}"
+    if len(datasets) < 2:
+        if nomatch == "ignore":
+            pass
+        elif nomatch == "warn":
+            warnings.warn(nomatch_msg)
+        elif nomatch == "raise":
+            raise RuntimeError(nomatch_msg)
+        else:
+            # Could this be done in an annotation? Or some other built-in way to do this?
+            raise ValueError(
+                f"Invalid input ({nomatch}) for `nomatch`, should be `ignore`, `warn`, or `raise`"
+            )
 
     return datasets
diff --git a/docs/whats-new.rst b/docs/whats-new.rst
index f8913566..8239af8c 100644
--- a/docs/whats-new.rst
+++ b/docs/whats-new.rst
@@ -2,9 +2,20 @@
 What's New
 ===========
 
+.. _whats-new.0.5.0:
+
+v0.5.0 (unreleased)
+-------------------
+
+New Features
+~~~~~~~~~~~~
+
+- :py:meth:`~cmip6_preprocessing.drift_removal.match_and_remove_trend` enables batch detrending/drift removal
+  from a dictionary of datasets (:pull:`155`). By `Julius Busecke `_
+
 .. _whats-new.0.4.0:
 
-v0.4.0 (unreleased)
+v0.4.0 (2021/6/9)
 -------------------
 
 New Features
diff --git a/tests/test_drift_removal.py b/tests/test_drift_removal.py
index 28abe4e6..101f6ff2 100644
--- a/tests/test_drift_removal.py
+++ b/tests/test_drift_removal.py
@@ -9,10 +9,12 @@
     _construct_cfdate,
     calculate_drift,
     find_date_idx,
+    match_and_remove_trend,
     remove_trend,
     replace_time,
     unify_time,
 )
+from cmip6_preprocessing.postprocessing import exact_attrs
 
 
 # I copied this from a PR I made to parcels a while back.
@@ -527,3 +529,135 @@ def test_calculate_drift_exceptions_partial():
     with pytest.warns(UserWarning) as winfo:
         reg = calculate_drift(ds_control, ds, "test", compute_short_trends=True)
     assert "years to calculate trend. Using 1 years only" in winfo[0].message.args[0]
+
+
+@pytest.mark.parametrize("ref_date", ["1850", "2000-01-02"])
+def test_match_and_remove_trend_matching_experiment(ref_date):
+    # construct a dict of data to be detrended
+    nx, ny = (10, 20)
+    nt = 24
+    time_historical = xr.cftime_range("1850-01-01", periods=nt, freq="1MS")
+    time_ssp = xr.cftime_range("2014-01-01", periods=nt, freq="1MS")
+    raw_attrs = {k: "dummy" for k in exact_attrs + ["variable_id"]}
+
+    ds_a_hist_vara = xr.DataArray(
+        np.random.rand(nx, ny, nt),
+        dims=["x", "y", "time"],
+        coords={"time": time_historical},
+    ).to_dataset(name="vara")
+    ds_a_hist_vara.attrs = {k: v for k, v in raw_attrs.items()}
+    ds_a_hist_vara.attrs["variant_label"] = "a"
+    ds_a_hist_vara.attrs["variable_id"] = "vara"
+    ds_a_hist_vara.attrs["experiment_id"] = "historical"
+
+    ds_a_hist_varb = xr.DataArray(
+        np.random.rand(nx, ny, nt),
+        dims=["x", "y", "time"],
+        coords={"time": time_historical},
+    ).to_dataset(name="varb")
+    ds_a_hist_varb.attrs = {k: v for k, v in raw_attrs.items()}
+    ds_a_hist_varb.attrs["variant_label"] = "a"
+    ds_a_hist_varb.attrs["variable_id"] = "varb"
+    ds_a_hist_varb.attrs["experiment_id"] = "historical"
+    ds_a_other_vara = xr.DataArray(
+        np.random.rand(nx, ny, nt),
+        dims=["x", "y", "time"],
+        coords={"time": time_ssp},
+    ).to_dataset(name="vara")
+    ds_a_other_vara.attrs = {k: v for k, v in raw_attrs.items()}
+    ds_a_other_vara.attrs["variant_label"] = "a"
+    ds_a_other_vara.attrs["variable_id"] = "vara"
+    ds_a_other_vara.attrs["experiment_id"] = "other"
+
+    ds_b_hist_vara = xr.DataArray(
+        np.random.rand(nx, ny, nt),
+        dims=["x", "y", "time"],
+        coords={"time": time_historical},
+    ).to_dataset(name="vara")
+    ds_b_hist_vara.attrs = {k: v for k, v in raw_attrs.items()}
+    ds_b_hist_vara.attrs["variant_label"] = "b"
+    ds_b_hist_vara.attrs["variable_id"] = "vara"
+    ds_b_hist_vara.attrs["experiment_id"] = "historical"
+
+    ddict = {
+        "ds_a_hist_vara": ds_a_hist_vara,
+        "ds_a_hist_varb": ds_a_hist_varb,
+        "ds_a_other_vara": ds_a_other_vara,
+        "ds_b_hist_vara": ds_b_hist_vara,
+    }
+
+    trend_a_vara = (
+        xr.ones_like(ds_a_hist_vara.isel(time=0)).drop_vars("time") * np.random.rand()
+    )
+    trend_a_varb = (
+        xr.ones_like(ds_a_hist_varb.isel(time=0)).drop_vars("time") * np.random.rand()
+    )
+
+    trend_b_vara = (
+        xr.ones_like(ds_b_hist_vara.isel(time=0)).drop_vars("time") * np.random.rand()
+    )
+    # trend_b_varb = (
+    #     xr.ones_like(ds_b_hist_varb.isel(time=0)).drop_vars("time") * np.random.rand()
+    # )
+    # print(trend_b_varb)
+
+    ddict_trend = {
+        n: ds for n, ds in enumerate([trend_a_vara, trend_b_vara, trend_a_varb])
+    }
+
+    ddict_detrended = match_and_remove_trend(ddict, ddict_trend, ref_date=ref_date)
+
+    for name, ds, trend_ds in [
+        ("ds_a_hist_vara", ds_a_hist_vara, trend_a_vara),
+        ("ds_b_hist_vara", ds_b_hist_vara, trend_b_vara),
+        ("ds_a_hist_varb", ds_a_hist_varb, trend_a_varb),
+        ("ds_a_other_vara", ds_a_other_vara, trend_a_vara),
+    ]:
+        variable = ds.attrs["variable_id"]
+        expected = remove_trend(ds, trend_ds, variable, ref_date).to_dataset(
+            name=variable
+        )
+        xr.testing.assert_allclose(
+            ddict_detrended[name],
+            expected,
+        )
+
+
+def test_match_and_remove_trend_nomatch():
+    # create two datasets that do not match (according to the hardcoded conventions in `match_and_remove_trend`)
+    attrs = {}
+    ds = xr.DataArray().to_dataset(name="test")
+    ds.attrs = {k: "a" for k in exact_attrs + ["variable_id"]}
+    ds_nomatch = xr.DataArray().to_dataset(name="test")
+    ds_nomatch.attrs = {k: "b" for k in exact_attrs + ["variable_id"]}
["variable_id"]} + + detrended = match_and_remove_trend({"aa": ds}, {"bb": ds_nomatch}, nomatch="ignore") + assert detrended == {} + + match_msg = "Could not find a matching dataset for *" + with pytest.warns(UserWarning, match=match_msg): + detrended = match_and_remove_trend( + {"aa": ds}, {"bb": ds_nomatch}, nomatch="warn" + ) + + with pytest.raises(RuntimeError, match=match_msg): + detrended = match_and_remove_trend( + {"aa": ds}, {"bb": ds_nomatch}, nomatch="raise" + ) + + +def test_match_and_remove_trend_nonunique(): + # create two datasets that do not match (according to the hardcoded conventions in `match_and_detrend`) + attrs = {} + ds = xr.DataArray().to_dataset(name="test") + ds.attrs = {k: "a" for k in exact_attrs + ["variable_id"]} + ds_match_a = xr.DataArray().to_dataset(name="test") + ds_match_b = xr.DataArray().to_dataset(name="test") + ds_match_a.attrs = ds.attrs + ds_match_b.attrs = ds.attrs + + match_msg = "Found more than one matching dataset for *" + with pytest.raises(ValueError, match=match_msg): + detrended = match_and_remove_trend( + {"aa": ds}, {"bb": ds_match_a, "cc": ds_match_b} + )