More extract features for RDRS #182

Merged
merged 13 commits on Apr 6, 2023
3 changes: 3 additions & 0 deletions HISTORY.rst
@@ -9,6 +9,7 @@ Contributors to this version: Trevor James Smith (:user:`Zeitsperre`), Juliette
Announcements
^^^^^^^^^^^^^
* `xscen` is now offered as a conda package available through Anaconda.org. Refer to the installation documentation for more information. (:issue:`149`, :pull:`171`).
* Deprecation: Release 0.6.0 of `xscen` will be the last version to support ``xscen.extract.clisops_subset``. (:pull:`182`).

New features and enhancements
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -20,6 +21,8 @@ New features and enhancements
* Allow passing ``GeoDataFrame`` instances in ``spatial_mean``'s ``region`` argument, not only geospatial file paths. (:pull:`174`).
* Allow searching for periods in `catalog.search`. (:issue:`123`, :pull:`170`).
* Allow searching and extracting multiple frequencies for a given variable. (:issue:`168`, :pull:`170`).
* New masking feature in ``extract_dataset``. (:issue:`180`, :pull:`182`).
* New function ``xs.spatial.subset``, which replaces ``xs.extract.clisops_subset`` and adds a "sel" method. (:issue:`180`, :pull:`182`).

Breaking changes
^^^^^^^^^^^^^^^^
113 changes: 38 additions & 75 deletions xscen/extract.py
@@ -5,18 +5,14 @@
import re
import warnings
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
from typing import Callable, List, Optional, Union

import clisops.core.subset
import dask
import numpy as np
import pandas as pd
import xarray as xr
import xclim as xc
from intake_esm.derived import DerivedVariableRegistry
from xclim.core.utils import uses_dask

from .catalog import DataCatalog # noqa
from .catalog import (
@@ -28,6 +24,7 @@
)
from .config import parse_config
from .indicators import load_xclim_module, registry_from_module
from .spatial import subset
from .utils import CV
from .utils import ensure_correct_time as _ensure_correct_time
from .utils import natural_sort
@@ -73,73 +70,13 @@ def clisops_subset(ds: xr.Dataset, region: dict) -> xr.Dataset:
--------
clisops.core.subset.subset_gridpoint, clisops.core.subset.subset_bbox, clisops.core.subset.subset_shape
"""
if uses_dask(ds.lon) or uses_dask(ds.lat):
warnings.warn("Loading longitude and latitude for more efficient subsetting.")
ds["lon"], ds["lat"] = dask.compute(ds.lon, ds.lat)
if "buffer" in region.keys():
# estimate the model resolution
if len(ds.lon.dims) == 1: # 1D lat-lon
lon_res = np.abs(ds.lon.diff("lon")[0].values)
lat_res = np.abs(ds.lat.diff("lat")[0].values)
else:
lon_res = np.abs(ds.lon[0, 0].values - ds.lon[0, 1].values)
lat_res = np.abs(ds.lat[0, 0].values - ds.lat[1, 0].values)

kwargs = deepcopy(region[region["method"]])

if region["method"] in ["gridpoint"]:
ds_subset = clisops.core.subset_gridpoint(ds, **kwargs)
new_history = (
f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"{region['method']} spatial subsetting on {len(region['gridpoint']['lon'])} coordinates - clisops v{clisops.__version__}"
)

elif region["method"] in ["bbox"]:
if "buffer" in region.keys():
# adjust the boundaries
kwargs["lon_bnds"] = (
kwargs["lon_bnds"][0] - lon_res * region["buffer"],
kwargs["lon_bnds"][1] + lon_res * region["buffer"],
)
kwargs["lat_bnds"] = (
kwargs["lat_bnds"][0] - lat_res * region["buffer"],
kwargs["lat_bnds"][1] + lat_res * region["buffer"],
)

if xc.core.utils.uses_dask(ds.cf["longitude"]):
ds[ds.cf["longitude"].name].load()
if xc.core.utils.uses_dask(ds.cf["latitude"]):
ds[ds.cf["latitude"].name].load()

ds_subset = clisops.core.subset_bbox(ds, **kwargs)
new_history = (
f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"{region['method']} spatial subsetting with {'buffer=' + str(region['buffer']) if 'buffer' in region else 'no buffer'}"
f", lon_bnds={np.array(region['bbox']['lon_bnds'])}, lat_bnds={np.array(region['bbox']['lat_bnds'])}"
f" - clisops v{clisops.__version__}"
)

elif region["method"] in ["shape"]:
if "buffer" in region.keys():
kwargs["buffer"] = np.max([lon_res, lat_res]) * region["buffer"]

ds_subset = clisops.core.subset_shape(ds, **kwargs)
new_history = (
f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"{region['method']} spatial subsetting with {'buffer=' + str(region['buffer']) if 'buffer' in region else 'no buffer'}"
f", shape={Path(region['shape']['shape']).name if isinstance(region['shape']['shape'], (str, Path)) else 'gpd.GeoDataFrame'}"
f" - clisops v{clisops.__version__}"
)

else:
raise ValueError("Subsetting type not recognized")

history = (
new_history + " \n " + ds_subset.attrs["history"]
if "history" in ds_subset.attrs
else new_history
)
ds_subset.attrs["history"] = history

warnings.warn(
"clisops_subset is deprecated and will not be available in future versions. "
"Use xscen.spatial.subset instead.",
category=FutureWarning,
)

ds_subset = subset(ds, region)

return ds_subset
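
For reference, a minimal migration sketch, assuming a toy 1D lat-lon grid and an illustrative `region` dictionary (neither comes from this PR): the deprecated wrapper keeps working but now emits a `FutureWarning` and delegates to the new function, so callers only need to swap the call site.

```python
import warnings

import numpy as np
import xarray as xr
import xscen as xs

# Toy 1D lat-lon dataset with CF-style coordinate attributes (illustrative only).
ds = xr.Dataset(
    {"tas": (("lat", "lon"), np.zeros((20, 40)))},
    coords={
        "lat": ("lat", np.linspace(40, 60, 20), {"units": "degrees_north"}),
        "lon": ("lon", np.linspace(-90, -50, 40), {"units": "degrees_east"}),
    },
)

# Hypothetical region dictionary; the fields follow the docstring above.
region = {
    "name": "example_bbox",
    "method": "bbox",
    "bbox": {"lon_bnds": [-80, -60], "lat_bnds": [45, 55]},
    "buffer": 1.5,
}

# Deprecated path: still available, but it emits a FutureWarning and
# simply delegates to xscen.spatial.subset.
with warnings.catch_warnings():
    warnings.simplefilter("always", FutureWarning)
    ds_old = xs.extract.clisops_subset(ds, region)

# Replacement: same region dictionary, same result, no deprecation warning.
ds_new = xs.spatial.subset(ds, region)
```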

@@ -157,6 +94,7 @@ def extract_dataset(
xr_combine_kwargs: dict = None,
preprocess: Callable = None,
resample_methods: Optional[dict] = None,
mask: Union[bool, xr.Dataset, xr.DataArray] = False,
) -> Union[dict, xr.Dataset]:
"""Take one element of the output of `search_data_catalogs` and returns a dataset, performing conversions and resampling as needed.

@@ -174,7 +112,7 @@
[start, end] of the period to be evaluated (or a list of lists)
Will be read from catalog._requested_periods if None. Leave both None to extract everything.
region : dict, optional
Description of the region and the subsetting method (required fields listed in the Notes).
Description of the region and the subsetting method (required fields listed in the Notes), as passed to `xscen.spatial.subset`.
to_level : str
The processing level to assign to the output.
Defaults to 'extracted'
@@ -197,6 +135,12 @@
If the method is not given for a variable, it is guessed from the variable name and frequency,
using the mapping in CVs/resampling_methods.json. If the variable is not found there,
"mean" is used by default.
mask : xr.Dataset, xr.DataArray, or bool
A mask that is applied to all variables: data is kept only where the mask is True.
Where the mask is False, variable values are replaced by NaNs.
The mask should have the same dimensions as the extracted variables.
If `mask` is a Dataset, it should contain a variable named 'mask'.
If `mask` is True, a `mask` variable at xrfreq `fx` is expected to have been extracted.

Returns
-------
@@ -211,7 +155,7 @@
name: str
Region name used to overwrite domain in the catalog.
method: str
['gridpoint', 'bbox', 'shape']
['gridpoint', 'bbox', 'shape', 'sel']
<method>: dict
Arguments specific to the method used.
buffer: float, optional
@@ -375,10 +319,9 @@ def extract_dataset(
slices.extend([ds.sel({"time": slice(str(period[0]), str(period[1]))})])
ds = xr.concat(slices, dim="time", **xr_combine_kwargs)

# Custom call to clisops
# subset to the region
if region is not None:
ds = clisops_subset(ds, region)
ds.attrs["cat:domain"] = region["name"]
ds = subset(ds, region)

# add relevant attrs
ds.attrs["cat:processing_level"] = to_level
Expand All @@ -387,6 +330,26 @@ def extract_dataset(

out_dict[xrfreq] = ds

if mask:
if isinstance(mask, xr.Dataset):
ds_mask = mask["mask"]
elif isinstance(mask, xr.DataArray):
ds_mask = mask
elif (
"fx" in out_dict and "mask" in out_dict["fx"]
): # get mask that was extracted above
ds_mask = out_dict["fx"]["mask"].copy()
else:
raise ValueError(
"No mask found. Either pass a xr.Dataset/xr.DataArray to the `mask` argument or pass a `dc` that includes a dataset with a variable named `mask`."
)

# iter over all xrfreq to apply the mask
for xrfreq, ds in out_dict.items():
out_dict[xrfreq] = ds.where(ds_mask)
if xrfreq == "fx": # put back the mask
out_dict[xrfreq]["mask"] = ds_mask

return out_dict


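To make the new `mask` behaviour concrete, here is a standalone sketch of what the block above does to each extracted frequency. The toy datasets and mask are assumptions for illustration; in practice this logic runs inside `extract_dataset`, which is fed by a `DataCatalog`.

```python
import numpy as np
import xarray as xr

# Toy stand-ins for the dictionary of extracted datasets keyed by xrfreq.
coords = {"lat": np.arange(4), "lon": np.arange(5)}
out_dict = {
    "D": xr.Dataset({"tas": (("lat", "lon"), np.random.rand(4, 5))}, coords=coords),
    "fx": xr.Dataset({"orog": (("lat", "lon"), np.random.rand(4, 5))}, coords=coords),
}

# A boolean mask on the same grid (True = keep, False = replace with NaN),
# equivalent to passing an xr.DataArray to the `mask` argument.
ds_mask = xr.DataArray(
    np.random.rand(4, 5) > 0.3, dims=("lat", "lon"), coords=coords, name="mask"
)

# Same logic as the new block above: mask every frequency, then re-attach
# the mask itself to the 'fx' dataset so it is carried along with the output.
for xrfreq, ds in out_dict.items():
    out_dict[xrfreq] = ds.where(ds_mask)
    if xrfreq == "fx":
        out_dict[xrfreq]["mask"] = ds_mask
```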
133 changes: 133 additions & 0 deletions xscen/spatial.py
@@ -1,9 +1,23 @@
"""Spatial tools."""
import datetime
import itertools
import warnings
from copy import deepcopy
from pathlib import Path

import clisops.core.subset
import dask
import numpy as np
import sparse as sp
import xarray as xr
import xclim as xc
from xclim.core.utils import uses_dask

__all__ = [
"creep_weights",
"creep_fill",
"subset",
]


def creep_weights(mask, n=1, mode="clip"):
@@ -103,3 +117,122 @@ def _dot(arr, wei):
dask="parallelized",
output_dtypes=["float64"],
)


def subset(ds: xr.Dataset, region: dict) -> xr.Dataset:
"""
Subset the data to a region.

Either creates a slice and uses the .sel() method, or customizes a call to the
`clisops.core.subset` functions, allowing for an automatic buffer around the region.

Parameters
----------
ds : xr.Dataset
Dataset to be subsetted
region : dict
Description of the region and the subsetting method (required fields listed in the Notes)

Notes
-----
'region' fields:
name: str
Region name used to overwrite domain in the catalog.
method: str
['gridpoint', 'bbox', 'shape', 'sel']
If the method is `sel`, clisops is not called; the dataset is simply subset with the xarray .sel() method.
The keys are the dimensions to subset on and the values are turned into slices.
<method>: dict
Arguments specific to the method used.
buffer: float, optional
Multiplier to apply to the model resolution.

Returns
-------
xr.Dataset
Subsetted Dataset.

See Also
--------
clisops.core.subset.subset_gridpoint, clisops.core.subset.subset_bbox, clisops.core.subset.subset_shape
"""
if uses_dask(ds.lon) or uses_dask(ds.lat):
warnings.warn("Loading longitude and latitude for more efficient subsetting.")
ds["lon"], ds["lat"] = dask.compute(ds.lon, ds.lat)
if "buffer" in region.keys():
# estimate the model resolution
if len(ds.lon.dims) == 1: # 1D lat-lon
lon_res = np.abs(ds.lon.diff("lon")[0].values)
lat_res = np.abs(ds.lat.diff("lat")[0].values)
else:
lon_res = np.abs(ds.lon[0, 0].values - ds.lon[0, 1].values)
lat_res = np.abs(ds.lat[0, 0].values - ds.lat[1, 0].values)

kwargs = deepcopy(region[region["method"]])

if region["method"] in ["gridpoint"]:
ds_subset = clisops.core.subset_gridpoint(ds, **kwargs)
new_history = (
f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"{region['method']} spatial subsetting on {len(region['gridpoint']['lon'])} coordinates - clisops v{clisops.__version__}"
)

elif region["method"] in ["bbox"]:
if "buffer" in region.keys():
# adjust the boundaries
kwargs["lon_bnds"] = (
kwargs["lon_bnds"][0] - lon_res * region["buffer"],
kwargs["lon_bnds"][1] + lon_res * region["buffer"],
)
kwargs["lat_bnds"] = (
kwargs["lat_bnds"][0] - lat_res * region["buffer"],
kwargs["lat_bnds"][1] + lat_res * region["buffer"],
)

if xc.core.utils.uses_dask(ds.cf["longitude"]):
ds[ds.cf["longitude"].name].load()
if xc.core.utils.uses_dask(ds.cf["latitude"]):
ds[ds.cf["latitude"].name].load()

ds_subset = clisops.core.subset_bbox(ds, **kwargs)
new_history = (
f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"{region['method']} spatial subsetting with {'buffer=' + str(region['buffer']) if 'buffer' in region else 'no buffer'}"
f", lon_bnds={np.array(region['bbox']['lon_bnds'])}, lat_bnds={np.array(region['bbox']['lat_bnds'])}"
f" - clisops v{clisops.__version__}"
)

elif region["method"] in ["shape"]:
if "buffer" in region.keys():
kwargs["buffer"] = np.max([lon_res, lat_res]) * region["buffer"]

ds_subset = clisops.core.subset_shape(ds, **kwargs)
new_history = (
f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"{region['method']} spatial subsetting with {'buffer=' + str(region['buffer']) if 'buffer' in region else 'no buffer'}"
f", shape={Path(region['shape']['shape']).name if isinstance(region['shape']['shape'], (str, Path)) else 'gpd.GeoDataFrame'}"
f" - clisops v{clisops.__version__}"
)

elif region["method"] in ["sel"]:
arg_sel = {
dim: slice(*map(float, bounds)) for dim, bounds in region["sel"].items()
}
ds_subset = ds.sel(**arg_sel)
new_history = (
f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"{region['method']} subsetting with arguments {arg_sel}"
)

else:
raise ValueError("Subsetting type not recognized")

history = (
new_history + " \n " + ds_subset.attrs["history"]
if "history" in ds_subset.attrs
else new_history
)
ds_subset.attrs["history"] = history
ds_subset.attrs["cat:domain"] = region["name"]

return ds_subset
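
For reference, a short sketch of `region` dictionaries accepted by the new function, assuming a toy 1D lat-lon grid; the coordinate bounds and names are illustrative, not taken from the PR.

```python
import numpy as np
import xarray as xr
import xscen as xs

# Toy 1D lat-lon dataset with CF-style coordinate attributes (illustrative only).
ds = xr.Dataset(
    {"pr": (("lat", "lon"), np.zeros((20, 40)))},
    coords={
        "lat": ("lat", np.linspace(40, 60, 20), {"units": "degrees_north"}),
        "lon": ("lon", np.linspace(-90, -50, 40), {"units": "degrees_east"}),
    },
)

# 'bbox' method: goes through clisops, with the bounds widened by
# two grid cells on each side because of the buffer.
region_bbox = {
    "name": "example_bbox",
    "method": "bbox",
    "bbox": {"lon_bnds": [-75, -70], "lat_bnds": [45, 50]},
    "buffer": 2,
}
ds_bbox = xs.spatial.subset(ds, region_bbox)

# New 'sel' method: no clisops call, each entry is turned into a plain
# slice on the corresponding dimension.
region_sel = {
    "name": "example_sel",
    "method": "sel",
    "sel": {"lon": [-75, -70], "lat": [45, 50]},
}
ds_sel = xs.spatial.subset(ds, region_sel)
```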