From 1d2e5c608765bef788f22191de54d1baeaaa7767 Mon Sep 17 00:00:00 2001
From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
Date: Wed, 20 Sep 2023 01:23:23 -0700
Subject: [PATCH 01/46] Add comments on when to use which `TypeVar` (#8212)

* Add comments on when to use which `TypeVar`

From the chats of #8208
---
 xarray/core/types.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/xarray/core/types.py b/xarray/core/types.py
index f80c2c52cd7..e9e700b038e 100644
--- a/xarray/core/types.py
+++ b/xarray/core/types.py
@@ -154,11 +154,21 @@ def copy(
 T_Array = TypeVar("T_Array", bound="AbstractArray")
 T_Index = TypeVar("T_Index", bound="Index")
 
+# `T_Xarray` is a type variable that can be either "DataArray" or "Dataset". When used
+# in a function definition, all inputs and outputs annotated with `T_Xarray` must be of
+# the same concrete type, either "DataArray" or "Dataset". This is generally preferred
+# over `T_DataArrayOrSet`, since the type system can then determine the exact type.
+T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset")
+
+# `T_DataArrayOrSet` is a type variable that is bound to either "DataArray" or
+# "Dataset". Use it for functions that might return either type, but where the exact
+# type cannot be determined statically using the type system.
 T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union["Dataset", "DataArray"])
 
-# Maybe we rename this to T_Data or something less Fortran-y?
-T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset")
+# For working directly with `DataWithCoords`. It will only allow using methods defined
+# on `DataWithCoords`.
 T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
+
 T_Alignable = TypeVar("T_Alignable", bound="Alignable")
 
 # Temporary placeholder for indicating an array api compliant type.

From 8c21376aa94deb91cb10e78f56ecd68aa7fe90fa Mon Sep 17 00:00:00 2001
From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
Date: Wed, 20 Sep 2023 02:38:32 -0700
Subject: [PATCH 02/46] Start a list of modules which require typing (#8198)

* Start a list of modules which require typing

Notes inline. Just one module so far!
---
 pyproject.toml             | 11 ++++++++++-
 xarray/core/rolling_exp.py | 20 +++++++++++++++-----
 2 files changed, 25 insertions(+), 6 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index dd380937bd2..663920f8dbb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -79,7 +79,7 @@ show_error_context = true
 warn_redundant_casts = true
 warn_unused_ignores = true
 
-# Most of the numerical computing stack doesn't have type annotations yet.
+# Much of the numerical computing stack doesn't have type annotations yet.
 [[tool.mypy.overrides]]
 ignore_missing_imports = true
 module = [
@@ -118,6 +118,15 @@ module = [
   "numpy.exceptions.*", # remove once support for `numpy<2.0` has been dropped
 ]
 
+# Gradually we want to add more modules to this list, ratcheting up our total
+# coverage. Once a module is here, functions require annotations in order to
+# pass mypy. It would be especially useful to have tests here, because without
+# annotating test functions we don't have a great way of testing our type
+# annotations — even just `-> None` is sufficient for mypy to check them.
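To make the two conventions above concrete (the constrained `T_Xarray` TypeVar from the first patch, and annotating tests with a bare `-> None`), here is a minimal sketch. It is not part of either patch; `double` and its test are hypothetical names, and the TypeVars are local stand-ins mirroring the ones in `xarray/core/types.py`:

```python
from typing import TypeVar, Union

import xarray as xr

# Local stand-ins for the TypeVars documented in xarray/core/types.py.
T_Xarray = TypeVar("T_Xarray", xr.DataArray, xr.Dataset)
T_DataArrayOrSet = TypeVar("T_DataArrayOrSet", bound=Union[xr.Dataset, xr.DataArray])


def double(obj: T_Xarray) -> T_Xarray:
    # Constrained TypeVar: a DataArray in means a DataArray out, and a
    # Dataset in means a Dataset out; mypy resolves the concrete type.
    return obj * 2


def test_double() -> None:
    # The bare `-> None` is all mypy needs to type-check this test body.
    da = xr.DataArray([1.0, 2.0], dims="x")
    assert double(da).identical(da * 2)
```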
+[[tool.mypy.overrides]]
+disallow_untyped_defs = true
+module = ["xarray.core.rolling_exp"]
+
 [tool.ruff]
 builtins = ["ellipsis"]
 exclude = [
diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py
index bd30c634aae..c56bf6a384e 100644
--- a/xarray/core/rolling_exp.py
+++ b/xarray/core/rolling_exp.py
@@ -9,10 +9,15 @@
 from xarray.core.options import _get_keep_attrs
 from xarray.core.pdcompat import count_not_none
 from xarray.core.pycompat import is_duck_dask_array
-from xarray.core.types import T_DataWithCoords
+from xarray.core.types import T_DataWithCoords, T_DuckArray
 
 
-def _get_alpha(com=None, span=None, halflife=None, alpha=None):
+def _get_alpha(
+    com: float | None = None,
+    span: float | None = None,
+    halflife: float | None = None,
+    alpha: float | None = None,
+) -> float:
     # pandas defines in terms of com (converting to alpha in the algo)
     # so use its function to get a com and then convert to alpha
 
@@ -20,7 +25,7 @@ def _get_alpha(com=None, span=None, halflife=None, alpha=None):
     return 1 / (1 + com)
 
 
-def move_exp_nanmean(array, *, axis, alpha):
+def move_exp_nanmean(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray:
     if is_duck_dask_array(array):
         raise TypeError("rolling_exp is not currently supported for dask-like arrays")
     import numbagg
@@ -32,7 +37,7 @@ def move_exp_nanmean(array, *, axis, alpha):
     return numbagg.move_exp_nanmean(array, axis=axis, alpha=alpha)
 
 
-def move_exp_nansum(array, *, axis, alpha):
+def move_exp_nansum(array: T_DuckArray, *, axis: int, alpha: float) -> np.ndarray:
     if is_duck_dask_array(array):
         raise TypeError("rolling_exp is not currently supported for dask-like arrays")
     import numbagg
@@ -40,7 +45,12 @@ def move_exp_nansum(array, *, axis, alpha):
     return numbagg.move_exp_nansum(array, axis=axis, alpha=alpha)
 
 
-def _get_center_of_mass(comass, span, halflife, alpha):
+def _get_center_of_mass(
+    comass: float | None,
+    span: float | None,
+    halflife: float | None,
+    alpha: float | None,
+) -> float:
     """
     Vendored from pandas.core.window.common._get_center_of_mass
 

From 04550e64089e58646be23d83695f32a1669db8eb Mon Sep 17 00:00:00 2001
From: Riulinchen <119889091+Riulinchen@users.noreply.github.com>
Date: Wed, 20 Sep 2023 21:25:57 +0200
Subject: [PATCH 03/46] Make documentation of DataArray.where clearer (#7955)

* Make doc of DataArray.where clearer

* Update xarray/core/common.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Loren
Co-authored-by: Deepak Cherian
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 xarray/core/common.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/xarray/core/common.py b/xarray/core/common.py
index ade701457c6..224b4154ef8 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -1066,6 +1066,9 @@ def where(
     ) -> T_DataWithCoords:
         """Filter elements from this object according to a condition.
 
+        Returns elements from 'DataArray', where 'cond' is True,
+        otherwise fill in 'other'.
+
         This operation follows the normal broadcasting and alignment rules
         that xarray uses for binary arithmetic.
 

From 2b784f24548a28a88a373c98722e562c7ddc7e01 Mon Sep 17 00:00:00 2001
From: Amrest Chinkamol
Date: Thu, 21 Sep 2023 02:31:49 +0700
Subject: [PATCH 04/46] Consistent `DatasetRolling.construct` behavior (#7578)

* Removed `.isel` for consistent rolling behavior.

`.isel` causes `DatasetRolling.construct` behavior to be inconsistent
with `DataArrayRolling.construct` when `stride` > 1.
* new rolling construct strategy for coords * add whats-new * add new tests with different coords * next try on aligning strided coords * add peakmem test for rolling.construct * increase asv benchmark rolling sizes --------- Co-authored-by: Michael Niklas --- asv_bench/benchmarks/rolling.py | 13 ++++++-- doc/whats-new.rst | 5 ++- xarray/core/rolling.py | 9 ++++-- xarray/tests/test_rolling.py | 55 ++++++++++++++++++++++++++++----- 4 files changed, 68 insertions(+), 14 deletions(-) diff --git a/asv_bench/benchmarks/rolling.py b/asv_bench/benchmarks/rolling.py index 1d3713f19bf..579f4f00fbc 100644 --- a/asv_bench/benchmarks/rolling.py +++ b/asv_bench/benchmarks/rolling.py @@ -5,10 +5,10 @@ from . import parameterized, randn, requires_dask -nx = 300 +nx = 3000 long_nx = 30000 ny = 200 -nt = 100 +nt = 1000 window = 20 randn_xy = randn((nx, ny), frac_nan=0.1) @@ -115,6 +115,11 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck): roll = self.ds.var3.rolling(t=100) getattr(roll, func)() + @parameterized(["stride"], ([None, 5, 50])) + def peakmem_1drolling_construct(self, stride): + self.ds.var2.rolling(t=100).construct("w", stride=stride) + self.ds.var3.rolling(t=100).construct("w", stride=stride) + class DatasetRollingMemory(RollingMemory): @parameterized(["func", "use_bottleneck"], (["sum", "max", "mean"], [True, False])) @@ -128,3 +133,7 @@ def peakmem_1drolling_reduce(self, func, use_bottleneck): with xr.set_options(use_bottleneck=use_bottleneck): roll = self.ds.rolling(t=100) getattr(roll, func)() + + @parameterized(["stride"], ([None, 5, 50])) + def peakmem_1drolling_construct(self, stride): + self.ds.rolling(t=100).construct("w", stride=stride) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ee74411a004..67429ed7e18 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -74,9 +74,12 @@ Bug fixes of :py:meth:`DataArray.__setitem__` lose dimension names. (:issue:`7030`, :pull:`8067`) By `Darsh Ranjan `_. - Return ``float64`` in presence of ``NaT`` in :py:class:`~core.accessor_dt.DatetimeAccessor` and - special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar()` + special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar` (:issue:`7928`, :pull:`8084`). By `Kai Mühlbauer `_. +- Fix :py:meth:`~core.rolling.DatasetRolling.construct` with stride on Datasets without indexes. + (:issue:`7021`, :pull:`7578`). + By `Amrest Chinkamol `_ and `Michael Niklas `_. - Calling plot with kwargs ``col``, ``row`` or ``hue`` no longer squeezes dimensions passed via these arguments (:issue:`7552`, :pull:`8174`). By `Wiktor Kraśnicki `_. diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index d49cb6e13a4..c6911cbe65b 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -785,11 +785,14 @@ def construct( if not keep_attrs: dataset[key].attrs = {} + # Need to stride coords as well. TODO: is there a better way? 
+ coords = self.obj.isel( + {d: slice(None, None, s) for d, s in zip(self.dim, strides)} + ).coords + attrs = self.obj.attrs if keep_attrs else {} - return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel( - {d: slice(None, None, s) for d, s in zip(self.dim, strides)} - ) + return Dataset(dataset, coords=coords, attrs=attrs) class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 9a15696b004..72d1b9071dd 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -175,7 +175,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None: @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("window", (1, 2, 3, 4)) - def test_rolling_construct(self, center, window) -> None: + def test_rolling_construct(self, center: bool, window: int) -> None: s = pd.Series(np.arange(10)) da = DataArray.from_series(s) @@ -610,7 +610,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None: @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("window", (1, 2, 3, 4)) - def test_rolling_construct(self, center, window) -> None: + def test_rolling_construct(self, center: bool, window: int) -> None: df = pd.DataFrame( { "x": np.random.randn(20), @@ -627,12 +627,6 @@ def test_rolling_construct(self, center, window) -> None: np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values) np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"]) - # with stride - ds_rolling_mean = ds_rolling.construct("window", stride=2).mean("window") - np.testing.assert_allclose( - df_rolling["x"][::2].values, ds_rolling_mean["x"].values - ) - np.testing.assert_allclose(df_rolling.index[::2], ds_rolling_mean["index"]) # with fill_value ds_rolling_mean = ds_rolling.construct("window", stride=2, fill_value=0.0).mean( "window" @@ -640,6 +634,51 @@ def test_rolling_construct(self, center, window) -> None: assert (ds_rolling_mean.isnull().sum() == 0).to_array(dim="vars").all() assert (ds_rolling_mean["x"] == 0.0).sum() >= 0 + @pytest.mark.parametrize("center", (True, False)) + @pytest.mark.parametrize("window", (1, 2, 3, 4)) + def test_rolling_construct_stride(self, center: bool, window: int) -> None: + df = pd.DataFrame( + { + "x": np.random.randn(20), + "y": np.random.randn(20), + "time": np.linspace(0, 1, 20), + } + ) + ds = Dataset.from_dataframe(df) + df_rolling_mean = df.rolling(window, center=center, min_periods=1).mean() + + # With an index (dimension coordinate) + ds_rolling = ds.rolling(index=window, center=center) + ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w") + np.testing.assert_allclose( + df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values + ) + np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"]) + + # Without index (https://github.com/pydata/xarray/issues/7021) + ds2 = ds.drop_vars("index") + ds2_rolling = ds2.rolling(index=window, center=center) + ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w") + np.testing.assert_allclose( + df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values + ) + + # Mixed coordinates, indexes and 2D coordinates + ds3 = xr.Dataset( + {"x": ("t", range(20)), "x2": ("y", range(5))}, + { + "t": range(20), + "y": ("y", range(5)), + "t2": ("t", range(20)), + "y2": ("y", range(5)), + "yt": (["t", "y"], np.ones((20, 5))), + }, + ) + ds3_rolling = ds3.rolling(t=window, center=center) + ds3_rolling_mean = 
ds3_rolling.construct("w", stride=2).mean("w") + for coord in ds3.coords: + assert coord in ds3_rolling_mean.coords + @pytest.mark.slow @pytest.mark.parametrize("ds", (1, 2), indirect=True) @pytest.mark.parametrize("center", (True, False)) From f60c1a3e969bfa580dfecca2e9ba7fee71447d9b Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 20 Sep 2023 12:57:16 -0700 Subject: [PATCH 05/46] Skip flaky test (#8219) * Skip flaky test --- xarray/tests/test_distributed.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 6a8cd9c457b..bfc37121597 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -168,6 +168,10 @@ def test_open_mfdataset_multiple_files_parallel_distributed(parallel, tmp_path): @requires_netCDF4 @pytest.mark.parametrize("parallel", (True, False)) def test_open_mfdataset_multiple_files_parallel(parallel, tmp_path): + if parallel: + pytest.skip( + "Flaky in CI. Would be a welcome contribution to make a similar test reliable." + ) lon = np.arange(100) time = xr.cftime_range("20010101", periods=100, calendar="360_day") data = np.random.random((time.size, lon.size)) From 96cf77a5ceaf849f8b867b4edc873bcb651a0b04 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 20 Sep 2023 15:57:53 -0700 Subject: [PATCH 06/46] Convert `indexes.py` to use `Self` for typing (#8217) * Convert `Variable` to use `Self` for typing I wanted to do this separately, as it's the only one that adds some casts. And given the ratio of impact-to-potential-merge-conflicts, I didn't want to slow the other PR down, even if it seems to be OK. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray/core/indexes.py Co-authored-by: Michael Niklas * Update xarray/core/indexes.py Co-authored-by: Michael Niklas * Update xarray/core/indexes.py Co-authored-by: Michael Niklas * Update xarray/core/indexes.py Co-authored-by: Michael Niklas * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michael Niklas --- xarray/core/indexes.py | 63 +++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 9972896d6df..1697762f7ae 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -24,7 +24,7 @@ ) if TYPE_CHECKING: - from xarray.core.types import ErrorOptions, JoinOptions, T_Index + from xarray.core.types import ErrorOptions, JoinOptions, Self from xarray.core.variable import Variable @@ -60,11 +60,11 @@ class Index: @classmethod def from_variables( - cls: type[T_Index], + cls, variables: Mapping[Any, Variable], *, options: Mapping[str, Any], - ) -> T_Index: + ) -> Self: """Create a new index object from one or more coordinate variables. This factory method must be implemented in all subclasses of Index. @@ -88,11 +88,11 @@ def from_variables( @classmethod def concat( - cls: type[T_Index], - indexes: Sequence[T_Index], + cls, + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> T_Index: + ) -> Self: """Create a new index by concatenating one or more indexes of the same type. 
@@ -120,9 +120,7 @@ def concat( raise NotImplementedError() @classmethod - def stack( - cls: type[T_Index], variables: Mapping[Any, Variable], dim: Hashable - ) -> T_Index: + def stack(cls, variables: Mapping[Any, Variable], dim: Hashable) -> Self: """Create a new index by stacking coordinate variables into a single new dimension. @@ -208,8 +206,8 @@ def to_pandas_index(self) -> pd.Index: raise TypeError(f"{self!r} cannot be cast to a pandas.Index object") def isel( - self: T_Index, indexers: Mapping[Any, int | slice | np.ndarray | Variable] - ) -> T_Index | None: + self, indexers: Mapping[Any, int | slice | np.ndarray | Variable] + ) -> Self | None: """Maybe returns a new index from the current index itself indexed by positional indexers. @@ -264,7 +262,7 @@ def sel(self, labels: dict[Any, Any]) -> IndexSelResult: """ raise NotImplementedError(f"{self!r} doesn't support label-based selection") - def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index: + def join(self, other: Self, how: JoinOptions = "inner") -> Self: """Return a new index from the combination of this index with another index of the same type. @@ -286,7 +284,7 @@ def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index: f"{self!r} doesn't support alignment with inner/outer join method" ) - def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]: + def reindex_like(self, other: Self) -> dict[Hashable, Any]: """Query the index with another index of the same type. Implementation is optional but required in order to support alignment. @@ -304,7 +302,7 @@ def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]: """ raise NotImplementedError(f"{self!r} doesn't support re-indexing labels") - def equals(self: T_Index, other: T_Index) -> bool: + def equals(self, other: Self) -> bool: """Compare this index with another index of the same type. Implementation is optional but required in order to support alignment. @@ -321,7 +319,7 @@ def equals(self: T_Index, other: T_Index) -> bool: """ raise NotImplementedError() - def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None: + def roll(self, shifts: Mapping[Any, int]) -> Self | None: """Roll this index by an offset along one or more dimensions. This method can be re-implemented in subclasses of Index, e.g., when the @@ -347,10 +345,10 @@ def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index | None: return None def rename( - self: T_Index, + self, name_dict: Mapping[Any, Hashable], dims_dict: Mapping[Any, Hashable], - ) -> T_Index: + ) -> Self: """Maybe update the index with new coordinate and dimension names. This method should be re-implemented in subclasses of Index if it has @@ -377,7 +375,7 @@ def rename( """ return self - def copy(self: T_Index, deep: bool = True) -> T_Index: + def copy(self, deep: bool = True) -> Self: """Return a (deep) copy of this index. Implementation in subclasses of Index is optional. 
The base class @@ -396,15 +394,13 @@ def copy(self: T_Index, deep: bool = True) -> T_Index: """ return self._copy(deep=deep) - def __copy__(self: T_Index) -> T_Index: + def __copy__(self) -> Self: return self.copy(deep=False) def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Index: return self._copy(deep=True, memo=memo) - def _copy( - self: T_Index, deep: bool = True, memo: dict[int, Any] | None = None - ) -> T_Index: + def _copy(self, deep: bool = True, memo: dict[int, Any] | None = None) -> Self: cls = self.__class__ copied = cls.__new__(cls) if deep: @@ -414,7 +410,7 @@ def _copy( copied.__dict__.update(self.__dict__) return copied - def __getitem__(self: T_Index, indexer: Any) -> T_Index: + def __getitem__(self, indexer: Any) -> Self: raise NotImplementedError() def _repr_inline_(self, max_width): @@ -674,10 +670,10 @@ def _concat_indexes(indexes, dim, positions=None) -> pd.Index: @classmethod def concat( cls, - indexes: Sequence[PandasIndex], + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> PandasIndex: + ) -> Self: new_pd_index = cls._concat_indexes(indexes, dim, positions) if not indexes: @@ -800,7 +796,11 @@ def equals(self, other: Index): return False return self.index.equals(other.index) and self.dim == other.dim - def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasIndex: + def join( + self, + other: Self, + how: str = "inner", + ) -> Self: if how == "outer": index = self.index.union(other.index) else: @@ -811,7 +811,7 @@ def join(self: PandasIndex, other: PandasIndex, how: str = "inner") -> PandasInd return type(self)(index, self.dim, coord_dtype=coord_dtype) def reindex_like( - self, other: PandasIndex, method=None, tolerance=None + self, other: Self, method=None, tolerance=None ) -> dict[Hashable, Any]: if not self.index.is_unique: raise ValueError( @@ -963,12 +963,12 @@ def from_variables( return obj @classmethod - def concat( # type: ignore[override] + def concat( cls, - indexes: Sequence[PandasMultiIndex], + indexes: Sequence[Self], dim: Hashable, positions: Iterable[Iterable[int]] | None = None, - ) -> PandasMultiIndex: + ) -> Self: new_pd_index = cls._concat_indexes(indexes, dim, positions) if not indexes: @@ -1602,7 +1602,7 @@ def to_pandas_indexes(self) -> Indexes[pd.Index]: return Indexes(indexes, self._variables, index_type=pd.Index) def copy_indexes( - self, deep: bool = True, memo: dict[int, Any] | None = None + self, deep: bool = True, memo: dict[int, T_PandasOrXarrayIndex] | None = None ) -> tuple[dict[Hashable, T_PandasOrXarrayIndex], dict[Hashable, Variable]]: """Return a new dictionary with copies of indexes, preserving unique indexes. 
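The motivation for replacing the `T_Index` TypeVar with `Self` throughout this file can be sketched with a toy hierarchy (hypothetical classes; `typing.Self` needs Python 3.11, or `typing_extensions` on older versions):

```python
import copy
from typing import Self  # Python 3.11+; otherwise: from typing_extensions import Self


class Index:
    def copy(self, deep: bool = True) -> Self:
        # `Self` resolves against the runtime class, so every subclass
        # inherits a precisely typed copy() with no overrides or casts.
        return copy.deepcopy(self) if deep else copy.copy(self)


class PandasIndex(Index):
    pass


# mypy infers PandasIndex here; the old `self: T_Index` spelling needed
# the TypeVar threaded through every method signature to achieve this.
idx: PandasIndex = PandasIndex().copy()
```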
@@ -1619,6 +1619,7 @@ def copy_indexes( new_indexes = {} new_index_vars = {} + idx: T_PandasOrXarrayIndex for idx, coords in self.group_by_index(): if isinstance(idx, pd.Index): convert_new_idx = True From 3ace2fb4612d4bc1cbce6fa22fe3954a0e06599e Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 20 Sep 2023 18:53:40 -0700 Subject: [PATCH 07/46] Use `Self` rather than concrete types, remove `cast`s (#8216) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Use `Self` rather than concrete types, remove `cast`s This should also allow for subtyping * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Undo one `Self` * Unused ignore * Add check for redundant self annotations * And `DataWithCoords` * And `DataArray` & `Dataset` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * And `Variable` * Update xarray/core/dataarray.py Co-authored-by: Michael Niklas * Update xarray/core/dataarray.py Co-authored-by: Michael Niklas * Update xarray/core/dataarray.py Co-authored-by: Michael Niklas * Clean-ups — `other`, casts, obsolete comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * another one --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michael Niklas --- pyproject.toml | 1 + xarray/core/accessor_str.py | 2 +- xarray/core/common.py | 37 ++-- xarray/core/concat.py | 8 +- xarray/core/coordinates.py | 28 +-- xarray/core/dataarray.py | 331 ++++++++++++++---------------- xarray/core/dataset.py | 364 ++++++++++++++++----------------- xarray/core/variable.py | 60 +++--- xarray/tests/test_dataarray.py | 2 +- 9 files changed, 400 insertions(+), 433 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 663920f8dbb..cb51c6ea741 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ source = ["xarray"] exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"] [tool.mypy] +enable_error_code = "redundant-self" exclude = 'xarray/util/generate_.*\.py' files = "xarray" show_error_codes = true diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index aa6dc2c7114..573200b5c88 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -2386,7 +2386,7 @@ def _partitioner( # _apply breaks on an empty array in this case if not self._obj.size: - return self._obj.copy().expand_dims({dim: 0}, axis=-1) # type: ignore[return-value] + return self._obj.copy().expand_dims({dim: 0}, axis=-1) arrfunc = lambda x, isep: np.array(func(x, isep), dtype=self._obj.dtype) diff --git a/xarray/core/common.py b/xarray/core/common.py index 224b4154ef8..e4e3e60e815 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -45,6 +45,7 @@ DatetimeLike, DTypeLikeSave, ScalarOrArray, + Self, SideOptions, T_Chunks, T_DataWithCoords, @@ -381,11 +382,11 @@ class DataWithCoords(AttrAccessMixin): __slots__ = ("_close",) def squeeze( - self: T_DataWithCoords, + self, dim: Hashable | Iterable[Hashable] | None = None, drop: bool = False, axis: int | Iterable[int] | None = None, - ) -> T_DataWithCoords: + ) -> Self: """Return a new object with squeezed data. 
Parameters @@ -414,12 +415,12 @@ def squeeze( return self.isel(drop=drop, **{d: 0 for d in dims}) def clip( - self: T_DataWithCoords, + self, min: ScalarOrArray | None = None, max: ScalarOrArray | None = None, *, keep_attrs: bool | None = None, - ) -> T_DataWithCoords: + ) -> Self: """ Return an array whose values are limited to ``[min, max]``. At least one of max or min must be given. @@ -472,10 +473,10 @@ def _calc_assign_results( return {k: v(self) if callable(v) else v for k, v in kwargs.items()} def assign_coords( - self: T_DataWithCoords, + self, coords: Mapping[Any, Any] | None = None, **coords_kwargs: Any, - ) -> T_DataWithCoords: + ) -> Self: """Assign new coordinates to this object. Returns a new object with all the original data in addition to the new @@ -620,9 +621,7 @@ def assign_coords( data.coords.update(results) return data - def assign_attrs( - self: T_DataWithCoords, *args: Any, **kwargs: Any - ) -> T_DataWithCoords: + def assign_attrs(self, *args: Any, **kwargs: Any) -> Self: """Assign new attrs to this object. Returns a new object equivalent to ``self.attrs.update(*args, **kwargs)``. @@ -1061,9 +1060,7 @@ def _resample( restore_coord_dims=restore_coord_dims, ) - def where( - self: T_DataWithCoords, cond: Any, other: Any = dtypes.NA, drop: bool = False - ) -> T_DataWithCoords: + def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: """Filter elements from this object according to a condition. Returns elements from 'DataArray', where 'cond' is True, @@ -1208,9 +1205,7 @@ def close(self) -> None: self._close() self._close = None - def isnull( - self: T_DataWithCoords, keep_attrs: bool | None = None - ) -> T_DataWithCoords: + def isnull(self, keep_attrs: bool | None = None) -> Self: """Test each value in the array for whether it is a missing value. Parameters @@ -1253,9 +1248,7 @@ def isnull( keep_attrs=keep_attrs, ) - def notnull( - self: T_DataWithCoords, keep_attrs: bool | None = None - ) -> T_DataWithCoords: + def notnull(self, keep_attrs: bool | None = None) -> Self: """Test each value in the array for whether it is not a missing value. Parameters @@ -1298,7 +1291,7 @@ def notnull( keep_attrs=keep_attrs, ) - def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords: + def isin(self, test_elements: Any) -> Self: """Tests each value in the array for whether it is in test elements. Parameters @@ -1347,7 +1340,7 @@ def isin(self: T_DataWithCoords, test_elements: Any) -> T_DataWithCoords: ) def astype( - self: T_DataWithCoords, + self, dtype, *, order=None, @@ -1355,7 +1348,7 @@ def astype( subok=None, copy=None, keep_attrs=True, - ) -> T_DataWithCoords: + ) -> Self: """ Copy of the xarray object, with data cast to a specified type. Leaves coordinate dtype unchanged. 
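For reference, the `where` docstring clarified in the earlier patch describes the following behavior; a small sketch using only standard xarray API:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=("x", "y"))

da.where(da % 2 == 0)            # odd entries are replaced by NaN
da.where(da % 2 == 0, other=-1)  # ...or by `other` where cond is False
da.where(da > 2, drop=True)      # also trim slices where cond is all-False
```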
@@ -1422,7 +1415,7 @@ def astype( dask="allowed", ) - def __enter__(self: T_DataWithCoords) -> T_DataWithCoords: + def __enter__(self) -> Self: return self def __exit__(self, exc_type, exc_value, traceback) -> None: diff --git a/xarray/core/concat.py b/xarray/core/concat.py index a76bb6b0033..a136480b2fb 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Hashable, Iterable -from typing import TYPE_CHECKING, Any, Union, cast, overload +from typing import TYPE_CHECKING, Any, Union, overload import numpy as np import pandas as pd @@ -504,8 +504,7 @@ def _dataset_concat( # case where concat dimension is a coordinate or data_var but not a dimension if (dim in coord_names or dim in data_names) and dim not in dim_names: - # TODO: Overriding type because .expand_dims has incorrect typing: - datasets = [cast(T_Dataset, ds.expand_dims(dim)) for ds in datasets] + datasets = [ds.expand_dims(dim) for ds in datasets] # determine which variables to concatenate concat_over, equals, concat_dim_lengths = _calc_concat_over( @@ -708,8 +707,7 @@ def _dataarray_concat( if compat == "identical": raise ValueError("array names not identical") else: - # TODO: Overriding type because .rename has incorrect typing: - arr = cast(T_DataArray, arr.rename(name)) + arr = arr.rename(name) datasets.append(arr._to_temp_dataset()) ds = _dataset_concat( diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index e20c022e637..97ba383ebde 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -23,7 +23,7 @@ create_default_index_implicit, ) from xarray.core.merge import merge_coordinates_without_align, merge_coords -from xarray.core.types import Self, T_DataArray +from xarray.core.types import Self, T_DataArray, T_Xarray from xarray.core.utils import ( Frozen, ReprObject, @@ -425,7 +425,7 @@ def __delitem__(self, key: Hashable) -> None: # redirect to DatasetCoordinates.__delitem__ del self._data.coords[key] - def equals(self, other: Coordinates) -> bool: + def equals(self, other: Self) -> bool: """Two Coordinates objects are equal if they have matching variables, all of which are equal. @@ -437,7 +437,7 @@ def equals(self, other: Coordinates) -> bool: return False return self.to_dataset().equals(other.to_dataset()) - def identical(self, other: Coordinates) -> bool: + def identical(self, other: Self) -> bool: """Like equals, but also checks all variable attributes. See Also @@ -565,9 +565,7 @@ def update(self, other: Mapping[Any, Any]) -> None: self._update_coords(coords, indexes) - def assign( - self, coords: Mapping | None = None, **coords_kwargs: Any - ) -> Coordinates: + def assign(self, coords: Mapping | None = None, **coords_kwargs: Any) -> Self: """Assign new coordinates (and indexes) to a Coordinates object, returning a new object with all the original coordinates in addition to the new ones. @@ -656,7 +654,7 @@ def copy( self, deep: bool = False, memo: dict[int, Any] | None = None, - ) -> Coordinates: + ) -> Self: """Return a copy of this Coordinates object.""" # do not copy indexes (may corrupt multi-coordinate indexes) # TODO: disable variables deepcopy? 
it may also be problematic when they @@ -664,8 +662,16 @@ def copy( variables = { k: v._copy(deep=deep, memo=memo) for k, v in self.variables.items() } - return Coordinates._construct_direct( - coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes) + + # TODO: getting an error with `self._construct_direct`, possibly because of how + # a subclass implements `_construct_direct`. (This was originally the same + # runtime code, but we switched the type definitions in #8216, which + # necessitates the cast.) + return cast( + Self, + Coordinates._construct_direct( + coords=variables, indexes=dict(self.xindexes), dims=dict(self.sizes) + ), ) @@ -915,9 +921,7 @@ def drop_indexed_coords( return Coordinates._construct_direct(coords=new_variables, indexes=new_indexes) -def assert_coordinate_consistent( - obj: T_DataArray | Dataset, coords: Mapping[Any, Variable] -) -> None: +def assert_coordinate_consistent(obj: T_Xarray, coords: Mapping[Any, Variable]) -> None: """Make sure the dimension coordinate of obj is consistent with coords. obj: DataArray or Dataset diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 791aad5cd17..73464c07c82 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4,7 +4,15 @@ import warnings from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence from os import PathLike -from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Literal, + NoReturn, + overload, +) import numpy as np import pandas as pd @@ -41,6 +49,7 @@ from xarray.core.indexing import is_fancy_indexer, map_index_queries from xarray.core.merge import PANDAS_TYPES, MergeError from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.types import DaCompatible, T_DataArray, T_DataArrayOrSet from xarray.core.utils import ( Default, HybridMappingProxy, @@ -100,8 +109,8 @@ QueryEngineOptions, QueryParserOptions, ReindexMethodOptions, + Self, SideOptions, - T_DataArray, T_Xarray, ) from xarray.core.weighted import DataArrayWeighted @@ -213,13 +222,13 @@ def _check_data_shape(data, coords, dims): return data -class _LocIndexer: +class _LocIndexer(Generic[T_DataArray]): __slots__ = ("data_array",) - def __init__(self, data_array: DataArray): + def __init__(self, data_array: T_DataArray): self.data_array = data_array - def __getitem__(self, key) -> DataArray: + def __getitem__(self, key) -> T_DataArray: if not utils.is_dict_like(key): # expand the indexer so we can handle Ellipsis labels = indexing.expanded_indexer(key, self.data_array.ndim) @@ -462,12 +471,12 @@ def __init__( @classmethod def _construct_direct( - cls: type[T_DataArray], + cls, variable: Variable, coords: dict[Any, Variable], name: Hashable, indexes: dict[Hashable, Index], - ) -> T_DataArray: + ) -> Self: """Shortcut around __init__ for internal use when we want to skip costly validation """ @@ -480,12 +489,12 @@ def _construct_direct( return obj def _replace( - self: T_DataArray, + self, variable: Variable | None = None, coords=None, name: Hashable | None | Default = _default, indexes=None, - ) -> T_DataArray: + ) -> Self: if variable is None: variable = self.variable if coords is None: @@ -497,10 +506,10 @@ def _replace( return type(self)(variable, coords, name=name, indexes=indexes, fastpath=True) def _replace_maybe_drop_dims( - self: T_DataArray, + self, variable: Variable, name: Hashable | None | Default = _default, - ) -> T_DataArray: + ) -> Self: if 
variable.dims == self.dims and variable.shape == self.shape: coords = self._coords.copy() indexes = self._indexes @@ -522,12 +531,12 @@ def _replace_maybe_drop_dims( return self._replace(variable, coords, name, indexes=indexes) def _overwrite_indexes( - self: T_DataArray, + self, indexes: Mapping[Any, Index], variables: Mapping[Any, Variable] | None = None, drop_coords: list[Hashable] | None = None, rename_dims: Mapping[Any, Any] | None = None, - ) -> T_DataArray: + ) -> Self: """Maybe replace indexes and their corresponding coordinates.""" if not indexes: return self @@ -560,8 +569,8 @@ def _to_temp_dataset(self) -> Dataset: return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False) def _from_temp_dataset( - self: T_DataArray, dataset: Dataset, name: Hashable | None | Default = _default - ) -> T_DataArray: + self, dataset: Dataset, name: Hashable | None | Default = _default + ) -> Self: variable = dataset._variables.pop(_THIS_ARRAY) coords = dataset._variables indexes = dataset._indexes @@ -773,7 +782,7 @@ def to_numpy(self) -> np.ndarray: """ return self.variable.to_numpy() - def as_numpy(self: T_DataArray) -> T_DataArray: + def as_numpy(self) -> Self: """ Coerces wrapped data and coordinates into numpy arrays, returning a DataArray. @@ -828,7 +837,7 @@ def _item_key_to_dict(self, key: Any) -> Mapping[Hashable, Any]: key = indexing.expanded_indexer(key, self.ndim) return dict(zip(self.dims, key)) - def _getitem_coord(self: T_DataArray, key: Any) -> T_DataArray: + def _getitem_coord(self, key: Any) -> Self: from xarray.core.dataset import _get_virtual_variable try: @@ -839,7 +848,7 @@ def _getitem_coord(self: T_DataArray, key: Any) -> T_DataArray: return self._replace_maybe_drop_dims(var, name=key) - def __getitem__(self: T_DataArray, key: Any) -> T_DataArray: + def __getitem__(self, key: Any) -> Self: if isinstance(key, str): return self._getitem_coord(key) else: @@ -909,7 +918,7 @@ def encoding(self) -> dict[Any, Any]: def encoding(self, value: Mapping[Any, Any]) -> None: self.variable.encoding = dict(value) - def reset_encoding(self: T_DataArray) -> T_DataArray: + def reset_encoding(self) -> Self: """Return a new DataArray without encoding on the array or any attached coords.""" ds = self._to_temp_dataset().reset_encoding() @@ -949,7 +958,7 @@ def coords(self) -> DataArrayCoordinates: @overload def reset_coords( - self: T_DataArray, + self, names: Dims = None, drop: Literal[False] = False, ) -> Dataset: @@ -957,18 +966,18 @@ def reset_coords( @overload def reset_coords( - self: T_DataArray, + self, names: Dims = None, *, drop: Literal[True], - ) -> T_DataArray: + ) -> Self: ... def reset_coords( - self: T_DataArray, + self, names: Dims = None, drop: bool = False, - ) -> T_DataArray | Dataset: + ) -> Self | Dataset: """Given names of coordinates, reset them to become variables. 
Parameters @@ -1080,15 +1089,15 @@ def __dask_postpersist__(self): func, args = self._to_temp_dataset().__dask_postpersist__() return self._dask_finalize, (self.name, func) + args - @staticmethod - def _dask_finalize(results, name, func, *args, **kwargs) -> DataArray: + @classmethod + def _dask_finalize(cls, results, name, func, *args, **kwargs) -> Self: ds = func(results, *args, **kwargs) variable = ds._variables.pop(_THIS_ARRAY) coords = ds._variables indexes = ds._indexes - return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) + return cls(variable, coords, name=name, indexes=indexes, fastpath=True) - def load(self: T_DataArray, **kwargs) -> T_DataArray: + def load(self, **kwargs) -> Self: """Manually trigger loading of this array's data from disk or a remote source into memory and return this array. @@ -1112,7 +1121,7 @@ def load(self: T_DataArray, **kwargs) -> T_DataArray: self._coords = new._coords return self - def compute(self: T_DataArray, **kwargs) -> T_DataArray: + def compute(self, **kwargs) -> Self: """Manually trigger loading of this array's data from disk or a remote source into memory and return a new array. The original is left unaltered. @@ -1134,7 +1143,7 @@ def compute(self: T_DataArray, **kwargs) -> T_DataArray: new = self.copy(deep=False) return new.load(**kwargs) - def persist(self: T_DataArray, **kwargs) -> T_DataArray: + def persist(self, **kwargs) -> Self: """Trigger computation in constituent dask arrays This keeps them as dask arrays but encourages them to keep data in @@ -1153,7 +1162,7 @@ def persist(self: T_DataArray, **kwargs) -> T_DataArray: ds = self._to_temp_dataset().persist(**kwargs) return self._from_temp_dataset(ds) - def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: + def copy(self, deep: bool = True, data: Any = None) -> Self: """Returns a copy of this array. If `deep=True`, a deep copy is made of the data array. @@ -1224,11 +1233,11 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: return self._copy(deep=deep, data=data) def _copy( - self: T_DataArray, + self, deep: bool = True, data: Any = None, memo: dict[int, Any] | None = None, - ) -> T_DataArray: + ) -> Self: variable = self.variable._copy(deep=deep, data=data, memo=memo) indexes, index_vars = self.xindexes.copy_indexes(deep=deep) @@ -1241,12 +1250,10 @@ def _copy( return self._replace(variable, coords, indexes=indexes) - def __copy__(self: T_DataArray) -> T_DataArray: + def __copy__(self) -> Self: return self._copy(deep=False) - def __deepcopy__( - self: T_DataArray, memo: dict[int, Any] | None = None - ) -> T_DataArray: + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: return self._copy(deep=True, memo=memo) # mutable objects should not be Hashable @@ -1287,7 +1294,7 @@ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: return get_chunksizes(all_variables) def chunk( - self: T_DataArray, + self, chunks: ( int | Literal["auto"] @@ -1302,7 +1309,7 @@ def chunk( chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, **chunks_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Coerce this array's data into a dask arrays with the given chunks. 
If this variable is a non-dask array, it will be converted to dask @@ -1380,12 +1387,12 @@ def chunk( return self._from_temp_dataset(ds) def isel( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, drop: bool = False, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by selecting indexes along the specified dimension(s). @@ -1471,13 +1478,13 @@ def isel( return self._replace(variable=variable, coords=coords, indexes=indexes) def sel( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance=None, drop: bool = False, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by selecting index labels along the specified dimension(s). @@ -1590,10 +1597,10 @@ def sel( return self._from_temp_dataset(ds) def head( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by the the first `n` values along the specified dimension(s). Default `n` = 5 @@ -1633,10 +1640,10 @@ def head( return self._from_temp_dataset(ds) def tail( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by the the last `n` values along the specified dimension(s). Default `n` = 5 @@ -1680,10 +1687,10 @@ def tail( return self._from_temp_dataset(ds) def thin( - self: T_DataArray, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Return a new DataArray whose data is given by each `n` value along the specified dimension(s). @@ -1730,10 +1737,10 @@ def thin( return self._from_temp_dataset(ds) def broadcast_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_DataArrayOrSet, exclude: Iterable[Hashable] | None = None, - ) -> T_DataArray: + ) -> Self: """Broadcast this DataArray against another Dataset or DataArray. This is equivalent to xr.broadcast(other, self)[1] @@ -1803,12 +1810,10 @@ def broadcast_like( dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) - return _broadcast_helper( - cast("T_DataArray", args[1]), exclude, dims_map, common_coords - ) + return _broadcast_helper(args[1], exclude, dims_map, common_coords) def _reindex_callback( - self: T_DataArray, + self, aligner: alignment.Aligner, dim_pos_indexers: dict[Hashable, Any], variables: dict[Hashable, Variable], @@ -1816,7 +1821,7 @@ def _reindex_callback( fill_value: Any, exclude_dims: frozenset[Hashable], exclude_vars: frozenset[Hashable], - ) -> T_DataArray: + ) -> Self: """Callback called from ``Aligner`` to create a new reindexed DataArray.""" if isinstance(fill_value, dict): @@ -1843,13 +1848,13 @@ def _reindex_callback( return da def reindex_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_DataArrayOrSet, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value=dtypes.NA, - ) -> T_DataArray: + ) -> Self: """Conform this object onto the indexes of another object, filling in missing values with ``fill_value``. The default fill value is NaN. 
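Likewise, `reindex_like` as converted above behaves as follows (standard xarray API, toy data):

```python
import numpy as np
import xarray as xr

a = xr.DataArray([10.0, 20.0, 30.0], dims="x", coords={"x": [0, 1, 2]})
b = xr.DataArray(np.zeros(4), dims="x", coords={"x": [0, 1, 2, 3]})

a.reindex_like(b)                # label x=3 is introduced, filled with NaN
a.reindex_like(b, fill_value=0)  # ...or filled with an explicit value
```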
@@ -2013,14 +2018,14 @@ def reindex_like( ) def reindex( - self: T_DataArray, + self, indexers: Mapping[Any, Any] | None = None, method: ReindexMethodOptions = None, tolerance: float | Iterable[float] | None = None, copy: bool = True, fill_value=dtypes.NA, **indexers_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Conform this object onto the indexes of another object, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -2104,13 +2109,13 @@ def reindex( ) def interp( - self: T_DataArray, + self, coords: Mapping[Any, Any] | None = None, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, **coords_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Interpolate a DataArray onto new coordinates Performs univariate or multivariate interpolation of a DataArray onto @@ -2247,12 +2252,12 @@ def interp( return self._from_temp_dataset(ds) def interp_like( - self: T_DataArray, - other: DataArray | Dataset, + self, + other: T_Xarray, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, - ) -> T_DataArray: + ) -> Self: """Interpolate this object onto the coordinates of another object, filling out of range values with NaN. @@ -2369,13 +2374,11 @@ def interp_like( ) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def rename( self, new_name_or_name_dict: Hashable | Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> DataArray: + ) -> Self: """Returns a new DataArray with renamed coordinates, dimensions or a new name. Parameters @@ -2416,10 +2419,10 @@ def rename( return self._replace(name=new_name_or_name_dict) def swap_dims( - self: T_DataArray, + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs, - ) -> T_DataArray: + ) -> Self: """Returns a new DataArray with swapped dimensions. Parameters @@ -2474,14 +2477,12 @@ def swap_dims( ds = self._to_temp_dataset().swap_dims(dims_dict) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def expand_dims( self, dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None, axis: None | int | Sequence[int] = None, **dim_kwargs: Any, - ) -> DataArray: + ) -> Self: """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a view into the underlying array, not a copy. @@ -2570,14 +2571,12 @@ def expand_dims( ds = self._to_temp_dataset().expand_dims(dim, axis) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def set_index( self, indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None, append: bool = False, **indexes_kwargs: Hashable | Sequence[Hashable], - ) -> DataArray: + ) -> Self: """Set DataArray (multi-)indexes using one or more existing coordinates. @@ -2635,13 +2634,11 @@ def set_index( ds = self._to_temp_dataset().set_index(indexes, append=append, **indexes_kwargs) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def reset_index( self, dims_or_levels: Hashable | Sequence[Hashable], drop: bool = False, - ) -> DataArray: + ) -> Self: """Reset the specified index(es) or multi-index level(s). 
This legacy method is specific to pandas (multi-)indexes and @@ -2675,11 +2672,11 @@ def reset_index( return self._from_temp_dataset(ds) def set_xindex( - self: T_DataArray, + self, coord_names: str | Sequence[Hashable], index_cls: type[Index] | None = None, **options, - ) -> T_DataArray: + ) -> Self: """Set a new, Xarray-compatible index from one or more existing coordinate(s). @@ -2704,10 +2701,10 @@ def set_xindex( return self._from_temp_dataset(ds) def reorder_levels( - self: T_DataArray, + self, dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None, **dim_order_kwargs: Sequence[int | Hashable], - ) -> T_DataArray: + ) -> Self: """Rearrange index levels using input order. Parameters @@ -2730,12 +2727,12 @@ def reorder_levels( return self._from_temp_dataset(ds) def stack( - self: T_DataArray, + self, dimensions: Mapping[Any, Sequence[Hashable]] | None = None, create_index: bool | None = True, index_cls: type[Index] = PandasMultiIndex, **dimensions_kwargs: Sequence[Hashable], - ) -> T_DataArray: + ) -> Self: """ Stack any number of existing dimensions into a single new dimension. @@ -2802,14 +2799,12 @@ def stack( ) return self._from_temp_dataset(ds) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def unstack( self, dim: Dims = None, fill_value: Any = dtypes.NA, sparse: bool = False, - ) -> DataArray: + ) -> Self: """ Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions. @@ -2933,11 +2928,11 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data return Dataset(data_dict) def transpose( - self: T_DataArray, + self, *dims: Hashable, transpose_coords: bool = True, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> T_DataArray: + ) -> Self: """Return a new DataArray object with transposed dimensions. Parameters @@ -2983,17 +2978,15 @@ def transpose( return self._replace(variable) @property - def T(self: T_DataArray) -> T_DataArray: + def T(self) -> Self: return self.transpose() - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def drop_vars( self, names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> DataArray: + ) -> Self: """Returns an array with dropped variables. Parameters @@ -3054,11 +3047,11 @@ def drop_vars( return self._from_temp_dataset(ds) def drop_indexes( - self: T_DataArray, + self, coord_names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_DataArray: + ) -> Self: """Drop the indexes assigned to the given coordinates. Parameters @@ -3079,13 +3072,13 @@ def drop_indexes( return self._from_temp_dataset(ds) def drop( - self: T_DataArray, + self, labels: Mapping[Any, Any] | None = None, dim: Hashable | None = None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_DataArray: + ) -> Self: """Backward compatible method based on `drop_vars` and `drop_sel` Using either `drop_vars` or `drop_sel` is encouraged @@ -3099,12 +3092,12 @@ def drop( return self._from_temp_dataset(ds) def drop_sel( - self: T_DataArray, + self, labels: Mapping[Any, Any] | None = None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_DataArray: + ) -> Self: """Drop index labels from this DataArray. 
Parameters @@ -3167,8 +3160,8 @@ def drop_sel( return self._from_temp_dataset(ds) def drop_isel( - self: T_DataArray, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs - ) -> T_DataArray: + self, indexers: Mapping[Any, Any] | None = None, **indexers_kwargs + ) -> Self: """Drop index positions from this DataArray. Parameters @@ -3218,11 +3211,11 @@ def drop_isel( return self._from_temp_dataset(dataset) def dropna( - self: T_DataArray, + self, dim: Hashable, how: Literal["any", "all"] = "any", thresh: int | None = None, - ) -> T_DataArray: + ) -> Self: """Returns a new array with dropped labels for missing values along the provided dimension. @@ -3293,7 +3286,7 @@ def dropna( ds = self._to_temp_dataset().dropna(dim, how=how, thresh=thresh) return self._from_temp_dataset(ds) - def fillna(self: T_DataArray, value: Any) -> T_DataArray: + def fillna(self, value: Any) -> Self: """Fill missing values in this object. This operation follows the normal broadcasting and alignment rules that @@ -3356,7 +3349,7 @@ def fillna(self: T_DataArray, value: Any) -> T_DataArray: return out def interpolate_na( - self: T_DataArray, + self, dim: Hashable | None = None, method: InterpOptions = "linear", limit: int | None = None, @@ -3372,7 +3365,7 @@ def interpolate_na( ) = None, keep_attrs: bool | None = None, **kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Fill in NaNs by interpolating according to different methods. Parameters @@ -3479,9 +3472,7 @@ def interpolate_na( **kwargs, ) - def ffill( - self: T_DataArray, dim: Hashable, limit: int | None = None - ) -> T_DataArray: + def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -3565,9 +3556,7 @@ def ffill( return ffill(self, dim, limit=limit) - def bfill( - self: T_DataArray, dim: Hashable, limit: int | None = None - ) -> T_DataArray: + def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward *Requires bottleneck.* @@ -3651,7 +3640,7 @@ def bfill( return bfill(self, dim, limit=limit) - def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray: + def combine_first(self, other: Self) -> Self: """Combine two DataArray objects, with union of coordinates. This operation follows the normal broadcasting and alignment rules of @@ -3670,7 +3659,7 @@ def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray: return ops.fillna(self, other, join="outer") def reduce( - self: T_DataArray, + self, func: Callable[..., Any], dim: Dims = None, *, @@ -3678,7 +3667,7 @@ def reduce( keep_attrs: bool | None = None, keepdims: bool = False, **kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Reduce this array by applying `func` along some dimension(s). Parameters @@ -3716,7 +3705,7 @@ def reduce( var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs) return self._replace_maybe_drop_dims(var) - def to_pandas(self) -> DataArray | pd.Series | pd.DataFrame: + def to_pandas(self) -> Self | pd.Series | pd.DataFrame: """Convert this array into a pandas object with the same shape. 
The type of the returned object depends on the number of DataArray @@ -4270,7 +4259,7 @@ def to_dict( return d @classmethod - def from_dict(cls: type[T_DataArray], d: Mapping[str, Any]) -> T_DataArray: + def from_dict(cls, d: Mapping[str, Any]) -> Self: """Convert a dictionary into an xarray.DataArray Parameters @@ -4387,7 +4376,7 @@ def to_cdms2(self) -> cdms2_Variable: return to_cdms2(self) @classmethod - def from_cdms2(cls, variable: cdms2_Variable) -> DataArray: + def from_cdms2(cls, variable: cdms2_Variable) -> Self: """Convert a cdms2.Variable into an xarray.DataArray .. deprecated:: 2023.06.0 @@ -4414,13 +4403,13 @@ def to_iris(self) -> iris_Cube: return to_iris(self) @classmethod - def from_iris(cls, cube: iris_Cube) -> DataArray: + def from_iris(cls, cube: iris_Cube) -> Self: """Convert a iris.cube.Cube into an xarray.DataArray""" from xarray.convert import from_iris return from_iris(cube) - def _all_compat(self: T_DataArray, other: T_DataArray, compat_str: str) -> bool: + def _all_compat(self, other: Self, compat_str: str) -> bool: """Helper function for equals, broadcast_equals, and identical""" def compat(x, y): @@ -4430,7 +4419,7 @@ def compat(x, y): self, other ) - def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool: + def broadcast_equals(self, other: Self) -> bool: """Two DataArrays are broadcast equal if they are equal after broadcasting them against each other such that they have the same dimensions. @@ -4479,7 +4468,7 @@ def broadcast_equals(self: T_DataArray, other: T_DataArray) -> bool: except (TypeError, AttributeError): return False - def equals(self: T_DataArray, other: T_DataArray) -> bool: + def equals(self, other: Self) -> bool: """True if two DataArrays have the same dimensions, coordinates and values; otherwise False. @@ -4541,7 +4530,7 @@ def equals(self: T_DataArray, other: T_DataArray) -> bool: except (TypeError, AttributeError): return False - def identical(self: T_DataArray, other: T_DataArray) -> bool: + def identical(self, other: Self) -> bool: """Like equals, but also checks the array name and attributes, and attributes on all coordinates. 
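The three comparison methods converted above differ as sketched here (standard xarray API, toy data):

```python
import xarray as xr

a = xr.DataArray([1, 2], dims="x", name="a", attrs={"units": "m"})
b = a.rename("b")  # same values and coords, different name

assert a.equals(b)         # dims, coords and values all match
assert not a.identical(b)  # identical() also compares name and attrs
assert a.broadcast_equals(a.expand_dims("y"))  # equal after broadcasting
```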
@@ -4608,19 +4597,19 @@ def _result_name(self, other: Any = None) -> Hashable | None: else: return None - def __array_wrap__(self: T_DataArray, obj, context=None) -> T_DataArray: + def __array_wrap__(self, obj, context=None) -> Self: new_var = self.variable.__array_wrap__(obj, context) return self._replace(new_var) - def __matmul__(self: T_DataArray, obj: T_DataArray) -> T_DataArray: + def __matmul__(self, obj: T_Xarray) -> T_Xarray: return self.dot(obj) - def __rmatmul__(self: T_DataArray, other: T_DataArray) -> T_DataArray: + def __rmatmul__(self, other: T_Xarray) -> T_Xarray: # currently somewhat duplicative, as only other DataArrays are # compatible with matmul return computation.dot(other, self) - def _unary_op(self: T_DataArray, f: Callable, *args, **kwargs) -> T_DataArray: + def _unary_op(self, f: Callable, *args, **kwargs) -> Self: keep_attrs = kwargs.pop("keep_attrs", None) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) @@ -4636,18 +4625,18 @@ def _unary_op(self: T_DataArray, f: Callable, *args, **kwargs) -> T_DataArray: return da def _binary_op( - self: T_DataArray, - other: Any, + self, + other: T_Xarray, f: Callable, reflexive: bool = False, - ) -> T_DataArray: + ) -> T_Xarray: from xarray.core.groupby import GroupBy if isinstance(other, (Dataset, GroupBy)): return NotImplemented if isinstance(other, DataArray): align_type = OPTIONS["arithmetic_join"] - self, other = align(self, other, join=align_type, copy=False) # type: ignore + self, other = align(self, other, join=align_type, copy=False) other_variable = getattr(other, "variable", other) other_coords = getattr(other, "coords", None) @@ -4661,7 +4650,7 @@ def _binary_op( return self._replace(variable, coords, name, indexes=indexes) - def _inplace_binary_op(self: T_DataArray, other: Any, f: Callable) -> T_DataArray: + def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self: from xarray.core.groupby import GroupBy if isinstance(other, GroupBy): @@ -4721,11 +4710,11 @@ def _title_for_slice(self, truncate: int = 50) -> str: return title def diff( - self: T_DataArray, + self, dim: Hashable, n: int = 1, label: Literal["upper", "lower"] = "upper", - ) -> T_DataArray: + ) -> Self: """Calculate the n-th order discrete difference along given axis. Parameters @@ -4771,11 +4760,11 @@ def diff( return self._from_temp_dataset(ds) def shift( - self: T_DataArray, + self, shifts: Mapping[Any, int] | None = None, fill_value: Any = dtypes.NA, **shifts_kwargs: int, - ) -> T_DataArray: + ) -> Self: """Shift this DataArray by an offset along one or more dimensions. Only the data is moved; coordinates stay in place. This is consistent @@ -4821,11 +4810,11 @@ def shift( return self._replace(variable=variable) def roll( - self: T_DataArray, + self, shifts: Mapping[Hashable, int] | None = None, roll_coords: bool = False, **shifts_kwargs: int, - ) -> T_DataArray: + ) -> Self: """Roll this array by an offset along one or more dimensions. Unlike shift, roll treats the given dimensions as periodic, so will not @@ -4870,7 +4859,7 @@ def roll( return self._from_temp_dataset(ds) @property - def real(self: T_DataArray) -> T_DataArray: + def real(self) -> Self: """ The real part of the array. @@ -4881,7 +4870,7 @@ def real(self: T_DataArray) -> T_DataArray: return self._replace(self.variable.real) @property - def imag(self: T_DataArray) -> T_DataArray: + def imag(self) -> Self: """ The imaginary part of the array. 
@@ -4892,10 +4881,10 @@ def imag(self: T_DataArray) -> T_DataArray:
         return self._replace(self.variable.imag)
 
     def dot(
-        self: T_DataArray,
-        other: T_DataArray,
+        self,
+        other: T_Xarray,
         dims: Dims = None,
-    ) -> T_DataArray:
+    ) -> T_Xarray:
         """Perform dot product of two DataArrays along their shared dims.
 
         Equivalent to taking tensordot over all shared dims.
@@ -4945,13 +4934,11 @@ def dot(
 
         return computation.dot(self, other, dims=dims)
 
-    # change type of self and return to T_DataArray once
-    # https://github.com/python/mypy/issues/12846 is resolved
     def sortby(
         self,
         variables: Hashable | DataArray | Sequence[Hashable | DataArray],
         ascending: bool = True,
-    ) -> DataArray:
+    ) -> Self:
         """Sort object by labels or values (along an axis).
 
         Sorts the dataarray, either along specified dimensions,
@@ -5012,14 +4999,14 @@ def sortby(
         return self._from_temp_dataset(ds)
 
     def quantile(
-        self: T_DataArray,
+        self,
         q: ArrayLike,
         dim: Dims = None,
         method: QuantileMethods = "linear",
         keep_attrs: bool | None = None,
         skipna: bool | None = None,
         interpolation: QuantileMethods | None = None,
-    ) -> T_DataArray:
+    ) -> Self:
         """Compute the qth quantile of the data along the specified dimension.
 
         Returns the qth quantiles(s) of the array elements.
@@ -5130,11 +5117,11 @@ def quantile(
         return self._from_temp_dataset(ds)
 
     def rank(
-        self: T_DataArray,
+        self,
         dim: Hashable,
         pct: bool = False,
         keep_attrs: bool | None = None,
-    ) -> T_DataArray:
+    ) -> Self:
         """Ranks the data.
 
         Equal values are assigned a rank that is the average of the ranks that
@@ -5174,11 +5161,11 @@ def rank(
         return self._from_temp_dataset(ds)
 
     def differentiate(
-        self: T_DataArray,
+        self,
         coord: Hashable,
         edge_order: Literal[1, 2] = 1,
         datetime_unit: DatetimeUnitOptions = None,
-    ) -> T_DataArray:
+    ) -> Self:
         """
         Differentiate the array with the second order accurate central
         differences.
@@ -5236,13 +5223,11 @@ def differentiate(
         ds = self._to_temp_dataset().differentiate(coord, edge_order, datetime_unit)
         return self._from_temp_dataset(ds)
 
-    # change type of self and return to T_DataArray once
-    # https://github.com/python/mypy/issues/12846 is resolved
     def integrate(
         self,
         coord: Hashable | Sequence[Hashable] = None,
         datetime_unit: DatetimeUnitOptions = None,
-    ) -> DataArray:
+    ) -> Self:
         """Integrate along the given coordinate using the trapezoidal rule.
 
         .. note::
@@ -5292,13 +5277,11 @@ def integrate(
         ds = self._to_temp_dataset().integrate(coord, datetime_unit)
         return self._from_temp_dataset(ds)
 
-    # change type of self and return to T_DataArray once
-    # https://github.com/python/mypy/issues/12846 is resolved
     def cumulative_integrate(
         self,
         coord: Hashable | Sequence[Hashable] = None,
         datetime_unit: DatetimeUnitOptions = None,
-    ) -> DataArray:
+    ) -> Self:
         """Integrate cumulatively along the given coordinate using the trapezoidal rule.
 
         .. note::
@@ -5356,7 +5339,7 @@ def cumulative_integrate(
         ds = self._to_temp_dataset().cumulative_integrate(coord, datetime_unit)
         return self._from_temp_dataset(ds)
 
-    def unify_chunks(self) -> DataArray:
+    def unify_chunks(self) -> Self:
         """Unify chunk size along all chunked dimensions of this DataArray.
Returns @@ -5541,7 +5524,7 @@ def polyfit( ) def pad( - self: T_DataArray, + self, pad_width: Mapping[Any, int | tuple[int, int]] | None = None, mode: PadModeOptions = "constant", stat_length: int @@ -5556,7 +5539,7 @@ def pad( reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, **pad_width_kwargs: Any, - ) -> T_DataArray: + ) -> Self: """Pad this array along one or more dimensions. .. warning:: @@ -5714,7 +5697,7 @@ def idxmin( skipna: bool | None = None, fill_value: Any = dtypes.NA, keep_attrs: bool | None = None, - ) -> DataArray: + ) -> Self: """Return the coordinate label of the minimum value along a dimension. Returns a new `DataArray` named after the dimension with the values of @@ -5810,7 +5793,7 @@ def idxmax( skipna: bool | None = None, fill_value: Any = dtypes.NA, keep_attrs: bool | None = None, - ) -> DataArray: + ) -> Self: """Return the coordinate label of the maximum value along a dimension. Returns a new `DataArray` named after the dimension with the values of @@ -5900,15 +5883,13 @@ def idxmax( keep_attrs=keep_attrs, ) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def argmin( self, dim: Dims = None, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, - ) -> DataArray | dict[Hashable, DataArray]: + ) -> Self | dict[Hashable, Self]: """Index or indices of the minimum of the DataArray over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of DataArrays, @@ -6002,15 +5983,13 @@ def argmin( else: return self._replace_maybe_drop_dims(result) - # change type of self and return to T_DataArray once - # https://github.com/python/mypy/issues/12846 is resolved def argmax( self, dim: Dims = None, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, - ) -> DataArray | dict[Hashable, DataArray]: + ) -> Self | dict[Hashable, Self]: """Index or indices of the maximum of the DataArray over one or more dimensions. If a sequence is passed to 'dim', then result returned as dict of DataArrays, @@ -6352,10 +6331,10 @@ def curvefit( ) def drop_duplicates( - self: T_DataArray, + self, dim: Hashable | Iterable[Hashable], keep: Literal["first", "last", False] = "first", - ) -> T_DataArray: + ) -> Self: """Returns a new DataArray with duplicate dimension values removed. Parameters @@ -6437,7 +6416,7 @@ def convert_calendar( align_on: str | None = None, missing: Any | None = None, use_cftime: bool | None = None, - ) -> DataArray: + ) -> Self: """Convert the DataArray to another calendar. Only converts the individual timestamps, does not modify any data except @@ -6557,7 +6536,7 @@ def interp_calendar( self, target: pd.DatetimeIndex | CFTimeIndex | DataArray, dim: str = "time", - ) -> DataArray: + ) -> Self: """Interpolates the DataArray to another calendar based on decimal year measure. Each timestamp in `source` and `target` are first converted to their decimal diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 48e25f7e1c7..9d771f0390c 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -93,7 +93,7 @@ is_duck_array, is_duck_dask_array, ) -from xarray.core.types import QuantileMethods, T_Dataset +from xarray.core.types import QuantileMethods, Self, T_DataArrayOrSet, T_Dataset from xarray.core.utils import ( Default, Frozen, @@ -698,11 +698,11 @@ def __init__( # TODO: dirty workaround for mypy 1.5 error with inherited DatasetOpsMixin vs. 
Mapping # related to https://github.com/python/mypy/issues/9319? - def __eq__(self: T_Dataset, other: DsCompatible) -> T_Dataset: # type: ignore[override] + def __eq__(self, other: DsCompatible) -> Self: # type: ignore[override] return super().__eq__(other) @classmethod - def load_store(cls: type[T_Dataset], store, decoder=None) -> T_Dataset: + def load_store(cls, store, decoder=None) -> Self: """Create a new dataset from the contents of a backends.*DataStore object """ @@ -746,7 +746,7 @@ def encoding(self) -> dict[Any, Any]: def encoding(self, value: Mapping[Any, Any]) -> None: self._encoding = dict(value) - def reset_encoding(self: T_Dataset) -> T_Dataset: + def reset_encoding(self) -> Self: """Return a new Dataset without encoding on the dataset or any of its variables/coords.""" variables = {k: v.reset_encoding() for k, v in self.variables.items()} @@ -802,7 +802,7 @@ def dtypes(self) -> Frozen[Hashable, np.dtype]: } ) - def load(self: T_Dataset, **kwargs) -> T_Dataset: + def load(self, **kwargs) -> Self: """Manually trigger loading and/or computation of this dataset's data from disk or a remote source into memory and return this dataset. Unlike compute, the original dataset is modified and returned. @@ -902,7 +902,7 @@ def __dask_postcompute__(self): def __dask_postpersist__(self): return self._dask_postpersist, () - def _dask_postcompute(self: T_Dataset, results: Iterable[Variable]) -> T_Dataset: + def _dask_postcompute(self, results: Iterable[Variable]) -> Self: import dask variables = {} @@ -925,8 +925,8 @@ def _dask_postcompute(self: T_Dataset, results: Iterable[Variable]) -> T_Dataset ) def _dask_postpersist( - self: T_Dataset, dsk: Mapping, *, rename: Mapping[str, str] | None = None - ) -> T_Dataset: + self, dsk: Mapping, *, rename: Mapping[str, str] | None = None + ) -> Self: from dask import is_dask_collection from dask.highlevelgraph import HighLevelGraph from dask.optimization import cull @@ -975,7 +975,7 @@ def _dask_postpersist( self._close, ) - def compute(self: T_Dataset, **kwargs) -> T_Dataset: + def compute(self, **kwargs) -> Self: """Manually trigger loading and/or computation of this dataset's data from disk or a remote source into memory and return a new dataset. Unlike load, the original dataset is left unaltered. 
@@ -997,7 +997,7 @@ def compute(self: T_Dataset, **kwargs) -> T_Dataset:
         new = self.copy(deep=False)
         return new.load(**kwargs)
 
-    def _persist_inplace(self: T_Dataset, **kwargs) -> T_Dataset:
+    def _persist_inplace(self, **kwargs) -> Self:
         """Persist all Dask arrays in memory"""
         # access .data to coerce everything to numpy or dask arrays
         lazy_data = {
@@ -1014,7 +1014,7 @@ def _persist_inplace(self: T_Dataset, **kwargs) -> T_Dataset:
 
         return self
 
-    def persist(self: T_Dataset, **kwargs) -> T_Dataset:
+    def persist(self, **kwargs) -> Self:
         """Trigger computation, keeping data as dask arrays
 
         This operation can be used to trigger computation on underlying dask
@@ -1037,7 +1037,7 @@ def persist(self: T_Dataset, **kwargs) -> T_Dataset:
 
     @classmethod
     def _construct_direct(
-        cls: type[T_Dataset],
+        cls,
         variables: dict[Any, Variable],
         coord_names: set[Hashable],
         dims: dict[Any, int] | None = None,
@@ -1045,7 +1045,7 @@ def _construct_direct(
         indexes: dict[Any, Index] | None = None,
         encoding: dict | None = None,
         close: Callable[[], None] | None = None,
-    ) -> T_Dataset:
+    ) -> Self:
         """Shortcut around __init__ for internal use when we want to skip
         costly validation
         """
@@ -1064,7 +1064,7 @@ def _construct_direct(
         return obj
 
     def _replace(
-        self: T_Dataset,
+        self,
         variables: dict[Hashable, Variable] | None = None,
         coord_names: set[Hashable] | None = None,
         dims: dict[Any, int] | None = None,
@@ -1072,7 +1072,7 @@ def _replace(
         indexes: dict[Hashable, Index] | None = None,
         encoding: dict | None | Default = _default,
         inplace: bool = False,
-    ) -> T_Dataset:
+    ) -> Self:
         """Fastpath constructor for internal use.
 
         Returns an object with optionally replaced attributes.
@@ -1114,13 +1114,13 @@ def _replace(
         return obj
 
     def _replace_with_new_dims(
-        self: T_Dataset,
+        self,
         variables: dict[Hashable, Variable],
         coord_names: set | None = None,
         attrs: dict[Hashable, Any] | None | Default = _default,
         indexes: dict[Hashable, Index] | None = None,
         inplace: bool = False,
-    ) -> T_Dataset:
+    ) -> Self:
         """Replace variables with recalculated dimensions."""
         dims = calculate_dimensions(variables)
         return self._replace(
@@ -1128,13 +1128,13 @@ def _replace_with_new_dims(
         )
 
     def _replace_vars_and_dims(
-        self: T_Dataset,
+        self,
         variables: dict[Hashable, Variable],
         coord_names: set | None = None,
         dims: dict[Hashable, int] | None = None,
         attrs: dict[Hashable, Any] | None | Default = _default,
         inplace: bool = False,
-    ) -> T_Dataset:
+    ) -> Self:
         """Deprecated version of _replace_with_new_dims().
 
         Unlike _replace_with_new_dims(), this method always recalculates
@@ -1147,13 +1147,13 @@ def _replace_vars_and_dims(
         )
 
     def _overwrite_indexes(
-        self: T_Dataset,
+        self,
         indexes: Mapping[Hashable, Index],
         variables: Mapping[Hashable, Variable] | None = None,
         drop_variables: list[Hashable] | None = None,
         drop_indexes: list[Hashable] | None = None,
         rename_dims: Mapping[Hashable, Hashable] | None = None,
-    ) -> T_Dataset:
+    ) -> Self:
         """Maybe replace indexes.
 
         This function may do a lot more depending on index query
@@ -1221,8 +1221,8 @@ def _overwrite_indexes(
         return replaced
 
     def copy(
-        self: T_Dataset, deep: bool = False, data: Mapping[Any, ArrayLike] | None = None
-    ) -> T_Dataset:
+        self, deep: bool = False, data: Mapping[Any, ArrayLike] | None = None
+    ) -> Self:
         """Returns a copy of this dataset.
 
         If `deep=True`, a deep copy is made of each of the component variables.
@@ -1322,11 +1322,11 @@ def copy( return self._copy(deep=deep, data=data) def _copy( - self: T_Dataset, + self, deep: bool = False, data: Mapping[Any, ArrayLike] | None = None, memo: dict[int, Any] | None = None, - ) -> T_Dataset: + ) -> Self: if data is None: data = {} elif not utils.is_dict_like(data): @@ -1364,13 +1364,13 @@ def _copy( return self._replace(variables, indexes=indexes, attrs=attrs, encoding=encoding) - def __copy__(self: T_Dataset) -> T_Dataset: + def __copy__(self) -> Self: return self._copy(deep=False) - def __deepcopy__(self: T_Dataset, memo: dict[int, Any] | None = None) -> T_Dataset: + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: return self._copy(deep=True, memo=memo) - def as_numpy(self: T_Dataset) -> T_Dataset: + def as_numpy(self) -> Self: """ Coerces wrapped data and coordinates into numpy arrays, returning a Dataset. @@ -1382,7 +1382,7 @@ def as_numpy(self: T_Dataset) -> T_Dataset: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - def _copy_listed(self: T_Dataset, names: Iterable[Hashable]) -> T_Dataset: + def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. """ @@ -1495,7 +1495,7 @@ def nbytes(self) -> int: return sum(v.nbytes for v in self.variables.values()) @property - def loc(self: T_Dataset) -> _LocIndexer[T_Dataset]: + def loc(self) -> _LocIndexer[Self]: """Attribute for location based indexing. Only supports __getitem__, and only when the key is a dict of the form {dim: labels}. """ @@ -1507,12 +1507,12 @@ def __getitem__(self, key: Hashable) -> DataArray: # Mapping is Iterable @overload - def __getitem__(self: T_Dataset, key: Iterable[Hashable]) -> T_Dataset: + def __getitem__(self, key: Iterable[Hashable]) -> Self: ... def __getitem__( - self: T_Dataset, key: Mapping[Any, Any] | Hashable | Iterable[Hashable] - ) -> T_Dataset | DataArray: + self, key: Mapping[Any, Any] | Hashable | Iterable[Hashable] + ) -> Self | DataArray: """Access variables or coordinates of this dataset as a :py:class:`~xarray.DataArray` or a subset of variables or a indexed dataset. @@ -1677,7 +1677,7 @@ def __delitem__(self, key: Hashable) -> None: # https://github.com/python/mypy/issues/4266 __hash__ = None # type: ignore[assignment] - def _all_compat(self, other: Dataset, compat_str: str) -> bool: + def _all_compat(self, other: Self, compat_str: str) -> bool: """Helper function for equals and identical""" # some stores (e.g., scipy) do not seem to preserve order, so don't @@ -1689,7 +1689,7 @@ def compat(x: Variable, y: Variable) -> bool: self._variables, other._variables, compat=compat ) - def broadcast_equals(self, other: Dataset) -> bool: + def broadcast_equals(self, other: Self) -> bool: """Two Datasets are broadcast equal if they are equal after broadcasting all variables against each other. @@ -1756,7 +1756,7 @@ def broadcast_equals(self, other: Dataset) -> bool: except (TypeError, AttributeError): return False - def equals(self, other: Dataset) -> bool: + def equals(self, other: Self) -> bool: """Two Datasets are equal if they have matching variables and coordinates, all of which are equal. 
@@ -1837,7 +1837,7 @@ def equals(self, other: Dataset) -> bool: except (TypeError, AttributeError): return False - def identical(self, other: Dataset) -> bool: + def identical(self, other: Self) -> bool: """Like equals, but also checks all dataset attributes and the attributes on all variables and coordinates. @@ -1950,7 +1950,7 @@ def data_vars(self) -> DataVariables: """Dictionary of DataArray objects corresponding to data variables""" return DataVariables(self) - def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Dataset: + def set_coords(self, names: Hashable | Iterable[Hashable]) -> Self: """Given names of one or more variables, set them as coordinates Parameters @@ -2008,10 +2008,10 @@ def set_coords(self: T_Dataset, names: Hashable | Iterable[Hashable]) -> T_Datas return obj def reset_coords( - self: T_Dataset, + self, names: Dims = None, drop: bool = False, - ) -> T_Dataset: + ) -> Self: """Given names of coordinates, reset them to become variables Parameters @@ -2562,7 +2562,7 @@ def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]: return get_chunksizes(self.variables.values()) def chunk( - self: T_Dataset, + self, chunks: ( int | Literal["auto"] | Mapping[Any, None | int | str | tuple[int, ...]] ) = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) @@ -2573,7 +2573,7 @@ def chunk( chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, **chunks_kwargs: None | int | str | tuple[int, ...], - ) -> T_Dataset: + ) -> Self: """Coerce all arrays in this dataset into dask arrays with the given chunks. @@ -2767,12 +2767,12 @@ def _get_indexers_coords_and_indexes(self, indexers): return attached_coords, attached_indexes def isel( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, drop: bool = False, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed along the specified dimension(s). @@ -2915,12 +2915,12 @@ def isel( ) def _isel_fancy( - self: T_Dataset, + self, indexers: Mapping[Any, Any], *, drop: bool, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> T_Dataset: + ) -> Self: valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) variables: dict[Hashable, Variable] = {} @@ -2956,13 +2956,13 @@ def _isel_fancy( return self._replace_with_new_dims(variables, coord_names, indexes=indexes) def sel( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, drop: bool = False, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed by tick labels along the specified dimension(s). @@ -3042,10 +3042,10 @@ def sel( return result._overwrite_indexes(*query_results.as_tuple()[1:]) def head( - self: T_Dataset, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with the first `n` values of each array for the specified dimension(s). @@ -3132,10 +3132,10 @@ def head( return self.isel(indexers_slices) def tail( - self: T_Dataset, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with the last `n` values of each array for the specified dimension(s). 
@@ -3223,10 +3223,10 @@ def tail( return self.isel(indexers_slices) def thin( - self: T_Dataset, + self, indexers: Mapping[Any, int] | int | None = None, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with each array indexed along every `n`-th value for the specified dimension(s) @@ -3308,10 +3308,10 @@ def thin( return self.isel(indexers_slices) def broadcast_like( - self: T_Dataset, - other: Dataset | DataArray, + self, + other: T_DataArrayOrSet, exclude: Iterable[Hashable] | None = None, - ) -> T_Dataset: + ) -> Self: """Broadcast this DataArray against another Dataset or DataArray. This is equivalent to xr.broadcast(other, self)[1] @@ -3331,12 +3331,10 @@ def broadcast_like( dims_map, common_coords = _get_broadcast_dims_map_common_coords(args, exclude) - return _broadcast_helper( - cast("T_Dataset", args[1]), exclude, dims_map, common_coords - ) + return _broadcast_helper(args[1], exclude, dims_map, common_coords) def _reindex_callback( - self: T_Dataset, + self, aligner: alignment.Aligner, dim_pos_indexers: dict[Hashable, Any], variables: dict[Hashable, Variable], @@ -3344,7 +3342,7 @@ def _reindex_callback( fill_value: Any, exclude_dims: frozenset[Hashable], exclude_vars: frozenset[Hashable], - ) -> T_Dataset: + ) -> Self: """Callback called from ``Aligner`` to create a new reindexed Dataset.""" new_variables = variables.copy() @@ -3397,13 +3395,13 @@ def _reindex_callback( return reindexed def reindex_like( - self: T_Dataset, - other: Dataset | DataArray, + self, + other: T_Xarray, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = xrdtypes.NA, - ) -> T_Dataset: + ) -> Self: """Conform this object onto the indexes of another object, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -3463,14 +3461,14 @@ def reindex_like( ) def reindex( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, fill_value: Any = xrdtypes.NA, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Conform this object onto a new set of indexes, filling in missing values with ``fill_value``. The default fill value is NaN. @@ -3679,7 +3677,7 @@ def reindex( ) def _reindex( - self: T_Dataset, + self, indexers: Mapping[Any, Any] | None = None, method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, @@ -3687,7 +3685,7 @@ def _reindex( fill_value: Any = xrdtypes.NA, sparse: bool = False, **indexers_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """ Same as reindex but supports sparse option. 
""" @@ -3703,14 +3701,14 @@ def _reindex( ) def interp( - self: T_Dataset, + self, coords: Mapping[Any, Any] | None = None, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, method_non_numeric: str = "nearest", **coords_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Interpolate a Dataset onto new coordinates Performs univariate or multivariate interpolation of a Dataset onto @@ -3983,12 +3981,12 @@ def _validate_interp_indexer(x, new_x): def interp_like( self, - other: Dataset | DataArray, + other: T_Xarray, method: InterpOptions = "linear", assume_sorted: bool = False, kwargs: Mapping[str, Any] | None = None, method_non_numeric: str = "nearest", - ) -> Dataset: + ) -> Self: """Interpolate this object onto the coordinates of another object, filling the out of range values with NaN. @@ -4138,10 +4136,10 @@ def _rename_all( return variables, coord_names, dims, indexes def _rename( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Also used internally by DataArray so that the warning (if any) is raised at the right stack level. """ @@ -4180,10 +4178,10 @@ def _rename( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def rename( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed variables, coordinates and dimensions. Parameters @@ -4210,10 +4208,10 @@ def rename( return self._rename(name_dict=name_dict, **names) def rename_dims( - self: T_Dataset, + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed dimensions only. Parameters @@ -4257,10 +4255,10 @@ def rename_dims( return self._replace(variables, coord_names, dims=sizes, indexes=indexes) def rename_vars( - self: T_Dataset, + self, name_dict: Mapping[Any, Hashable] | None = None, **names: Hashable, - ) -> T_Dataset: + ) -> Self: """Returns a new object with renamed variables including coordinates Parameters @@ -4297,8 +4295,8 @@ def rename_vars( return self._replace(variables, coord_names, dims=dims, indexes=indexes) def swap_dims( - self: T_Dataset, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs - ) -> T_Dataset: + self, dims_dict: Mapping[Any, Hashable] | None = None, **dims_kwargs + ) -> Self: """Returns a new object with swapped dimensions. Parameters @@ -4401,14 +4399,12 @@ def swap_dims( return self._replace_with_new_dims(variables, coord_names, indexes=indexes) - # change type of self and return to T_Dataset once - # https://github.com/python/mypy/issues/12846 is resolved def expand_dims( self, dim: None | Hashable | Sequence[Hashable] | Mapping[Any, Any] = None, axis: None | int | Sequence[int] = None, **dim_kwargs: Any, - ) -> Dataset: + ) -> Self: """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. The new object is a view into the underlying array, not a copy. 
@@ -4598,14 +4594,12 @@ def expand_dims(
             variables, coord_names=coord_names, indexes=indexes
         )
 
-    # change type of self and return to T_Dataset once
-    # https://github.com/python/mypy/issues/12846 is resolved
     def set_index(
         self,
         indexes: Mapping[Any, Hashable | Sequence[Hashable]] | None = None,
         append: bool = False,
         **indexes_kwargs: Hashable | Sequence[Hashable],
-    ) -> Dataset:
+    ) -> Self:
         """Set Dataset (multi-)indexes using one or more existing coordinates
         or variables.
@@ -4766,10 +4760,10 @@ def set_index(
         )
 
     def reset_index(
-        self: T_Dataset,
+        self,
         dims_or_levels: Hashable | Sequence[Hashable],
         drop: bool = False,
-    ) -> T_Dataset:
+    ) -> Self:
         """Reset the specified index(es) or multi-index level(s).
 
         This legacy method is specific to pandas (multi-)indexes and
@@ -4877,11 +4871,11 @@ def drop_or_convert(var_names):
         )
 
     def set_xindex(
-        self: T_Dataset,
+        self,
         coord_names: str | Sequence[Hashable],
         index_cls: type[Index] | None = None,
         **options,
-    ) -> T_Dataset:
+    ) -> Self:
         """Set a new, Xarray-compatible index from one or more existing
         coordinate(s).
@@ -4989,10 +4983,10 @@ def set_xindex(
         )
 
     def reorder_levels(
-        self: T_Dataset,
+        self,
         dim_order: Mapping[Any, Sequence[int | Hashable]] | None = None,
         **dim_order_kwargs: Sequence[int | Hashable],
-    ) -> T_Dataset:
+    ) -> Self:
         """Rearrange index levels using input order.
 
         Parameters
@@ -5093,12 +5087,12 @@ def _get_stack_index(
         return stack_index, stack_coords
 
     def _stack_once(
-        self: T_Dataset,
+        self,
         dims: Sequence[Hashable | ellipsis],
         new_dim: Hashable,
         index_cls: type[Index],
         create_index: bool | None = True,
-    ) -> T_Dataset:
+    ) -> Self:
         if dims == ...:
             raise ValueError("Please use [...] for dims, rather than just ...")
         if ... in dims:
@@ -5152,12 +5146,12 @@ def _stack_once(
         )
 
     def stack(
-        self: T_Dataset,
+        self,
         dimensions: Mapping[Any, Sequence[Hashable | ellipsis]] | None = None,
         create_index: bool | None = True,
         index_cls: type[Index] = PandasMultiIndex,
         **dimensions_kwargs: Sequence[Hashable | ellipsis],
-    ) -> T_Dataset:
+    ) -> Self:
         """
         Stack any number of existing dimensions into a single new dimension.
@@ -5312,12 +5306,12 @@ def stack_dataarray(da):
         return data_array
 
     def _unstack_once(
-        self: T_Dataset,
+        self,
         dim: Hashable,
         index_and_vars: tuple[Index, dict[Hashable, Variable]],
         fill_value,
         sparse: bool = False,
-    ) -> T_Dataset:
+    ) -> Self:
         index, index_vars = index_and_vars
         variables: dict[Hashable, Variable] = {}
         indexes = {k: v for k, v in self._indexes.items() if k != dim}
@@ -5352,12 +5346,12 @@ def _unstack_once(
         )
 
     def _unstack_full_reindex(
-        self: T_Dataset,
+        self,
         dim: Hashable,
         index_and_vars: tuple[Index, dict[Hashable, Variable]],
         fill_value,
         sparse: bool,
-    ) -> T_Dataset:
+    ) -> Self:
         index, index_vars = index_and_vars
         variables: dict[Hashable, Variable] = {}
         indexes = {k: v for k, v in self._indexes.items() if k != dim}
@@ -5403,11 +5397,11 @@ def _unstack_full_reindex(
         )
 
     def unstack(
-        self: T_Dataset,
+        self,
         dim: Dims = None,
         fill_value: Any = xrdtypes.NA,
         sparse: bool = False,
-    ) -> T_Dataset:
+    ) -> Self:
         """
         Unstack existing dimensions corresponding to MultiIndexes into
         multiple new dimensions.
@@ -5504,7 +5498,7 @@ def unstack(
         result = result._unstack_once(d, stacked_indexes[d], fill_value, sparse)
         return result
 
-    def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset:
+    def update(self, other: CoercibleMapping) -> Self:
         """Update this dataset's variables with those from another dataset.
 
         Just like :py:meth:`dict.update` this is an in-place operation.
@@ -5544,14 +5538,14 @@ def update(self: T_Dataset, other: CoercibleMapping) -> T_Dataset: return self._replace(inplace=True, **merge_result._asdict()) def merge( - self: T_Dataset, + self, other: CoercibleMapping | DataArray, overwrite_vars: Hashable | Iterable[Hashable] = frozenset(), compat: CompatOptions = "no_conflicts", join: JoinOptions = "outer", fill_value: Any = xrdtypes.NA, combine_attrs: CombineAttrsOptions = "override", - ) -> T_Dataset: + ) -> Self: """Merge the arrays of two datasets into a single dataset. This method generally does not allow for overriding data, with the @@ -5655,11 +5649,11 @@ def _assert_all_in_dataset( ) def drop_vars( - self: T_Dataset, + self, names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop variables from this dataset. Parameters @@ -5801,11 +5795,11 @@ def drop_vars( ) def drop_indexes( - self: T_Dataset, + self, coord_names: Hashable | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop the indexes assigned to the given coordinates. Parameters @@ -5857,13 +5851,13 @@ def drop_indexes( return self._replace(variables=variables, indexes=indexes) def drop( - self: T_Dataset, + self, labels=None, dim=None, *, errors: ErrorOptions = "raise", **labels_kwargs, - ) -> T_Dataset: + ) -> Self: """Backward compatible method based on `drop_vars` and `drop_sel` Using either `drop_vars` or `drop_sel` is encouraged @@ -5913,8 +5907,8 @@ def drop( return self.drop_sel(labels, errors=errors) def drop_sel( - self: T_Dataset, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs - ) -> T_Dataset: + self, labels=None, *, errors: ErrorOptions = "raise", **labels_kwargs + ) -> Self: """Drop index labels from this dataset. Parameters @@ -5983,7 +5977,7 @@ def drop_sel( ds = ds.loc[{dim: new_index}] return ds - def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset: + def drop_isel(self, indexers=None, **indexers_kwargs) -> Self: """Drop index positions from this Dataset. Parameters @@ -6049,11 +6043,11 @@ def drop_isel(self: T_Dataset, indexers=None, **indexers_kwargs) -> T_Dataset: return ds def drop_dims( - self: T_Dataset, + self, drop_dims: str | Iterable[Hashable], *, errors: ErrorOptions = "raise", - ) -> T_Dataset: + ) -> Self: """Drop dimensions and associated variables from this dataset. Parameters @@ -6090,10 +6084,10 @@ def drop_dims( return self.drop_vars(drop_vars) def transpose( - self: T_Dataset, + self, *dims: Hashable, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> T_Dataset: + ) -> Self: """Return a new Dataset object with all array dimensions transposed. Although the order of dimensions on each array will change, the dataset @@ -6146,12 +6140,12 @@ def transpose( return ds def dropna( - self: T_Dataset, + self, dim: Hashable, how: Literal["any", "all"] = "any", thresh: int | None = None, subset: Iterable[Hashable] | None = None, - ) -> T_Dataset: + ) -> Self: """Returns a new dataset with dropped labels for missing values along the provided dimension. @@ -6273,7 +6267,7 @@ def dropna( return self.isel({dim: mask}) - def fillna(self: T_Dataset, value: Any) -> T_Dataset: + def fillna(self, value: Any) -> Self: """Fill missing values in this object. 
This operation follows the normal broadcasting and alignment rules that @@ -6354,7 +6348,7 @@ def fillna(self: T_Dataset, value: Any) -> T_Dataset: return out def interpolate_na( - self: T_Dataset, + self, dim: Hashable | None = None, method: InterpOptions = "linear", limit: int | None = None, @@ -6363,7 +6357,7 @@ def interpolate_na( int | float | str | pd.Timedelta | np.timedelta64 | datetime.timedelta ) = None, **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Fill in NaNs by interpolating according to different methods. Parameters @@ -6493,7 +6487,7 @@ def interpolate_na( ) return new - def ffill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset: + def ffill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values forward *Requires bottleneck.* @@ -6557,7 +6551,7 @@ def ffill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) return new - def bfill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset: + def bfill(self, dim: Hashable, limit: int | None = None) -> Self: """Fill NaN values by propagating values backward *Requires bottleneck.* @@ -6622,7 +6616,7 @@ def bfill(self: T_Dataset, dim: Hashable, limit: int | None = None) -> T_Dataset new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit) return new - def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset: + def combine_first(self, other: Self) -> Self: """Combine two Datasets, default to data_vars of self. The new coordinates follow the normal broadcasting and alignment rules @@ -6642,7 +6636,7 @@ def combine_first(self: T_Dataset, other: T_Dataset) -> T_Dataset: return out def reduce( - self: T_Dataset, + self, func: Callable, dim: Dims = None, *, @@ -6650,7 +6644,7 @@ def reduce( keepdims: bool = False, numeric_only: bool = False, **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Reduce this dataset by applying `func` along some dimension(s). Parameters @@ -6775,12 +6769,12 @@ def reduce( ) def map( - self: T_Dataset, + self, func: Callable, keep_attrs: bool | None = None, args: Iterable[Any] = (), **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Apply a function to each data variable in this dataset Parameters @@ -6835,12 +6829,12 @@ def map( return type(self)(variables, attrs=attrs) def apply( - self: T_Dataset, + self, func: Callable, keep_attrs: bool | None = None, args: Iterable[Any] = (), **kwargs: Any, - ) -> T_Dataset: + ) -> Self: """ Backward compatible implementation of ``map`` @@ -6856,10 +6850,10 @@ def apply( return self.map(func, keep_attrs, args, **kwargs) def assign( - self: T_Dataset, + self, variables: Mapping[Any, Any] | None = None, **variables_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Assign new data variables to a Dataset, returning a new object with all the original variables in addition to the new ones. 
@@ -7164,9 +7158,7 @@ def _set_numpy_data_from_dataframe( self[name] = (dims, data) @classmethod - def from_dataframe( - cls: type[T_Dataset], dataframe: pd.DataFrame, sparse: bool = False - ) -> T_Dataset: + def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: """Convert a pandas.DataFrame into an xarray.Dataset Each column will be converted into an independent variable in the @@ -7380,7 +7372,7 @@ def to_dict( return d @classmethod - def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset: + def from_dict(cls, d: Mapping[Any, Any]) -> Self: """Convert a dictionary into an xarray.Dataset. Parameters @@ -7470,7 +7462,7 @@ def from_dict(cls: type[T_Dataset], d: Mapping[Any, Any]) -> T_Dataset: return obj - def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset: + def _unary_op(self, f, *args, **kwargs) -> Self: variables = {} keep_attrs = kwargs.pop("keep_attrs", None) if keep_attrs is None: @@ -7501,7 +7493,7 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: ds.attrs = self.attrs return ds - def _inplace_binary_op(self: T_Dataset, other, f) -> T_Dataset: + def _inplace_binary_op(self, other, f) -> Self: from xarray.core.dataarray import DataArray from xarray.core.groupby import GroupBy @@ -7576,11 +7568,11 @@ def _copy_attrs_from(self, other): self.variables[v].attrs = other.variables[v].attrs def diff( - self: T_Dataset, + self, dim: Hashable, n: int = 1, label: Literal["upper", "lower"] = "upper", - ) -> T_Dataset: + ) -> Self: """Calculate the n-th order discrete difference along given axis. Parameters @@ -7663,11 +7655,11 @@ def diff( return difference def shift( - self: T_Dataset, + self, shifts: Mapping[Any, int] | None = None, fill_value: Any = xrdtypes.NA, **shifts_kwargs: int, - ) -> T_Dataset: + ) -> Self: """Shift this dataset by an offset along one or more dimensions. Only data variables are moved; coordinates stay in place. This is @@ -7734,11 +7726,11 @@ def shift( return self._replace(variables) def roll( - self: T_Dataset, + self, shifts: Mapping[Any, int] | None = None, roll_coords: bool = False, **shifts_kwargs: int, - ) -> T_Dataset: + ) -> Self: """Roll this dataset by an offset along one or more dimensions. Unlike shift, roll treats the given dimensions as periodic, so will not @@ -7820,10 +7812,10 @@ def roll( return self._replace(variables, indexes=indexes) def sortby( - self: T_Dataset, + self, variables: Hashable | DataArray | list[Hashable | DataArray], ascending: bool = True, - ) -> T_Dataset: + ) -> Self: """ Sort object by labels or values (along an axis). @@ -7890,7 +7882,7 @@ def sortby( variables = variables arrays = [v if isinstance(v, DataArray) else self[v] for v in variables] aligned_vars = align(self, *arrays, join="left") # type: ignore[type-var] - aligned_self: T_Dataset = aligned_vars[0] # type: ignore[assignment] + aligned_self = cast(Self, aligned_vars[0]) aligned_other_vars: tuple[DataArray, ...] = aligned_vars[1:] # type: ignore[assignment] vars_by_dim = defaultdict(list) for data_array in aligned_other_vars: @@ -7906,7 +7898,7 @@ def sortby( return aligned_self.isel(indices) def quantile( - self: T_Dataset, + self, q: ArrayLike, dim: Dims = None, method: QuantileMethods = "linear", @@ -7914,7 +7906,7 @@ def quantile( keep_attrs: bool | None = None, skipna: bool | None = None, interpolation: QuantileMethods | None = None, - ) -> T_Dataset: + ) -> Self: """Compute the qth quantile of the data along the specified dimension. 
Returns the qth quantiles(s) of the array elements for each variable @@ -8084,11 +8076,11 @@ def quantile( return new.assign_coords(quantile=q) def rank( - self: T_Dataset, + self, dim: Hashable, pct: bool = False, keep_attrs: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Ranks the data. Equal values are assigned a rank that is the average of the ranks that @@ -8142,11 +8134,11 @@ def rank( return self._replace(variables, coord_names, attrs=attrs) def differentiate( - self: T_Dataset, + self, coord: Hashable, edge_order: Literal[1, 2] = 1, datetime_unit: DatetimeUnitOptions | None = None, - ) -> T_Dataset: + ) -> Self: """ Differentiate with the second order accurate central differences. @@ -8214,10 +8206,10 @@ def differentiate( return self._replace(variables) def integrate( - self: T_Dataset, + self, coord: Hashable | Sequence[Hashable], datetime_unit: DatetimeUnitOptions = None, - ) -> T_Dataset: + ) -> Self: """Integrate along the given coordinate using the trapezoidal rule. .. note:: @@ -8333,10 +8325,10 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): ) def cumulative_integrate( - self: T_Dataset, + self, coord: Hashable | Sequence[Hashable], datetime_unit: DatetimeUnitOptions = None, - ) -> T_Dataset: + ) -> Self: """Integrate along the given coordinate using the trapezoidal rule. .. note:: @@ -8408,7 +8400,7 @@ def cumulative_integrate( return result @property - def real(self: T_Dataset) -> T_Dataset: + def real(self) -> Self: """ The real part of each data variable. @@ -8419,7 +8411,7 @@ def real(self: T_Dataset) -> T_Dataset: return self.map(lambda x: x.real, keep_attrs=True) @property - def imag(self: T_Dataset) -> T_Dataset: + def imag(self) -> Self: """ The imaginary part of each data variable. @@ -8431,7 +8423,7 @@ def imag(self: T_Dataset) -> T_Dataset: plot = utils.UncachedAccessor(DatasetPlotAccessor) - def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset: + def filter_by_attrs(self, **kwargs) -> Self: """Returns a ``Dataset`` with variables that match specific conditions. Can pass in ``key=value`` or ``key=callable``. A Dataset is returned @@ -8526,7 +8518,7 @@ def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset: selection.append(var_name) return self[selection] - def unify_chunks(self: T_Dataset) -> T_Dataset: + def unify_chunks(self) -> Self: """Unify chunk size along all chunked dimensions of this Dataset. Returns @@ -8648,7 +8640,7 @@ def map_blocks( return map_blocks(func, self, args, kwargs, template) def polyfit( - self: T_Dataset, + self, dim: Hashable, deg: int, skipna: bool | None = None, @@ -8656,7 +8648,7 @@ def polyfit( w: Hashable | Any = None, full: bool = False, cov: bool | Literal["unscaled"] = False, - ) -> T_Dataset: + ) -> Self: """ Least squares polynomial fit. @@ -8844,7 +8836,7 @@ def polyfit( return type(self)(data_vars=variables, attrs=self.attrs.copy()) def pad( - self: T_Dataset, + self, pad_width: Mapping[Any, int | tuple[int, int]] | None = None, mode: PadModeOptions = "constant", stat_length: int @@ -8858,7 +8850,7 @@ def pad( reflect_type: PadReflectOptions = None, keep_attrs: bool | None = None, **pad_width_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Pad this dataset along one or more dimensions. .. 
warning:: @@ -9030,12 +9022,12 @@ def pad( return self._replace_with_new_dims(variables, indexes=indexes, attrs=attrs) def idxmin( - self: T_Dataset, + self, dim: Hashable | None = None, skipna: bool | None = None, fill_value: Any = xrdtypes.NA, keep_attrs: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Return the coordinate label of the minimum value along a dimension. Returns a new `Dataset` named after the dimension with the values of @@ -9127,12 +9119,12 @@ def idxmin( ) def idxmax( - self: T_Dataset, + self, dim: Hashable | None = None, skipna: bool | None = None, fill_value: Any = xrdtypes.NA, keep_attrs: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Return the coordinate label of the maximum value along a dimension. Returns a new `Dataset` named after the dimension with the values of @@ -9223,7 +9215,7 @@ def idxmax( ) ) - def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: + def argmin(self, dim: Hashable | None = None, **kwargs) -> Self: """Indices of the minima of the member variables. If there are multiple minima, the indices of the first one found will be @@ -9326,7 +9318,7 @@ def argmin(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: "Dataset.argmin() with a sequence or ... for dim" ) - def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: + def argmax(self, dim: Hashable | None = None, **kwargs) -> Self: """Indices of the maxima of the member variables. If there are multiple maxima, the indices of the first one found will be @@ -9420,13 +9412,13 @@ def argmax(self: T_Dataset, dim: Hashable | None = None, **kwargs) -> T_Dataset: ) def query( - self: T_Dataset, + self, queries: Mapping[Any, Any] | None = None, parser: QueryParserOptions = "pandas", engine: QueryEngineOptions = None, missing_dims: ErrorOptionsWithWarn = "raise", **queries_kwargs: Any, - ) -> T_Dataset: + ) -> Self: """Return a new dataset with each array indexed along the specified dimension(s), where the indexers are given as strings containing Python expressions to be evaluated against the data variables in the @@ -9516,7 +9508,7 @@ def query( return self.isel(indexers, missing_dims=missing_dims) def curvefit( - self: T_Dataset, + self, coords: str | DataArray | Iterable[str | DataArray], func: Callable[..., Any], reduce_dims: Dims = None, @@ -9526,7 +9518,7 @@ def curvefit( param_names: Sequence[str] | None = None, errors: ErrorOptions = "raise", kwargs: dict[str, Any] | None = None, - ) -> T_Dataset: + ) -> Self: """ Curve fitting optimization for arbitrary functions. @@ -9750,10 +9742,10 @@ def _wrapper(Y, *args, **kwargs): return result def drop_duplicates( - self: T_Dataset, + self, dim: Hashable | Iterable[Hashable], keep: Literal["first", "last", False] = "first", - ) -> T_Dataset: + ) -> Self: """Returns a new Dataset with duplicate dimension values removed. Parameters @@ -9793,13 +9785,13 @@ def drop_duplicates( return self.isel(indexes) def convert_calendar( - self: T_Dataset, + self, calendar: CFCalendar, dim: Hashable = "time", align_on: Literal["date", "year", None] = None, missing: Any | None = None, use_cftime: bool | None = None, - ) -> T_Dataset: + ) -> Self: """Convert the Dataset to another calendar. 
Only converts the individual timestamps, does not modify any data except @@ -9916,10 +9908,10 @@ def convert_calendar( ) def interp_calendar( - self: T_Dataset, + self, target: pd.DatetimeIndex | CFTimeIndex | DataArray, dim: Hashable = "time", - ) -> T_Dataset: + ) -> Self: """Interpolates the Dataset to another calendar based on decimal year measure. Each timestamp in `source` and `target` are first converted to their decimal diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 1869cf2a0bd..2571b093450 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -66,8 +66,8 @@ PadModeOptions, PadReflectOptions, QuantileMethods, + Self, T_DuckArray, - T_Variable, ) NON_NANOSECOND_WARNING = ( @@ -420,7 +420,7 @@ def _in_memory(self): ) @property - def data(self: T_Variable): + def data(self): """ The Variable's data as an array. The underlying array type (e.g. dask, sparse, pint) is preserved. @@ -439,7 +439,7 @@ def data(self: T_Variable): return self.values @data.setter - def data(self: T_Variable, data: T_DuckArray | ArrayLike) -> None: + def data(self, data: T_DuckArray | ArrayLike) -> None: data = as_compatible_data(data) if data.shape != self.shape: # type: ignore[attr-defined] raise ValueError( @@ -449,7 +449,7 @@ def data(self: T_Variable, data: T_DuckArray | ArrayLike) -> None: self._data = data def astype( - self: T_Variable, + self, dtype, *, order=None, @@ -457,7 +457,7 @@ def astype( subok=None, copy=None, keep_attrs=True, - ) -> T_Variable: + ) -> Self: """ Copy of the Variable object, with data cast to a specified type. @@ -883,7 +883,7 @@ def _broadcast_indexes_vectorized(self, key): return out_dims, VectorizedIndexer(tuple(out_key)), new_order - def __getitem__(self: T_Variable, key) -> T_Variable: + def __getitem__(self, key) -> Self: """Return a new Variable object whose contents are consistent with getting the provided key from the underlying data. @@ -902,7 +902,7 @@ def __getitem__(self: T_Variable, key) -> T_Variable: data = np.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) - def _finalize_indexing_result(self: T_Variable, dims, data) -> T_Variable: + def _finalize_indexing_result(self, dims, data) -> Self: """Used by IndexVariable to return IndexVariable objects when possible.""" return self._replace(dims=dims, data=data) @@ -1001,13 +1001,13 @@ def encoding(self, value): except ValueError: raise ValueError("encoding must be castable to a dictionary") - def reset_encoding(self: T_Variable) -> T_Variable: + def reset_encoding(self) -> Self: """Return a new Variable without encoding.""" return self._replace(encoding={}) def copy( - self: T_Variable, deep: bool = True, data: T_DuckArray | ArrayLike | None = None - ) -> T_Variable: + self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None + ) -> Self: """Returns a copy of this object. 
If `deep=True`, the data array is loaded into memory and copied onto @@ -1066,11 +1066,11 @@ def copy( return self._copy(deep=deep, data=data) def _copy( - self: T_Variable, + self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None, memo: dict[int, Any] | None = None, - ) -> T_Variable: + ) -> Self: if data is None: data_old = self._data @@ -1099,12 +1099,12 @@ def _copy( return self._replace(data=ndata, attrs=attrs, encoding=encoding) def _replace( - self: T_Variable, + self, dims=_default, data=_default, attrs=_default, encoding=_default, - ) -> T_Variable: + ) -> Self: if dims is _default: dims = copy.copy(self._dims) if data is _default: @@ -1115,12 +1115,10 @@ def _replace( encoding = copy.copy(self._encoding) return type(self)(dims, data, attrs, encoding, fastpath=True) - def __copy__(self: T_Variable) -> T_Variable: + def __copy__(self) -> Self: return self._copy(deep=False) - def __deepcopy__( - self: T_Variable, memo: dict[int, Any] | None = None - ) -> T_Variable: + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: return self._copy(deep=True, memo=memo) # mutable objects should not be hashable @@ -1179,7 +1177,7 @@ def chunk( chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, **chunks_kwargs: Any, - ) -> Variable: + ) -> Self: """Coerce this array's data into a dask array with the given chunks. If this variable is a non-dask array, it will be converted to dask @@ -1310,7 +1308,7 @@ def to_numpy(self) -> np.ndarray: return data - def as_numpy(self: T_Variable) -> T_Variable: + def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) @@ -1345,11 +1343,11 @@ def _to_dense(self): return self.copy(deep=False) def isel( - self: T_Variable, + self, indexers: Mapping[Any, Any] | None = None, missing_dims: ErrorOptionsWithWarn = "raise", **indexers_kwargs: Any, - ) -> T_Variable: + ) -> Self: """Return a new array indexed along the specified dimension(s). Parameters @@ -1636,7 +1634,7 @@ def transpose( self, *dims: Hashable | ellipsis, missing_dims: ErrorOptionsWithWarn = "raise", - ) -> Variable: + ) -> Self: """Return a new Variable object with transposed dimensions. Parameters @@ -1681,7 +1679,7 @@ def transpose( return self._replace(dims=dims, data=data) @property - def T(self) -> Variable: + def T(self) -> Self: return self.transpose() def set_dims(self, dims, shape=None): @@ -1789,9 +1787,7 @@ def stack(self, dimensions=None, **dimensions_kwargs): result = result._stack_once(dims, new_dim) return result - def _unstack_once_full( - self, dims: Mapping[Any, int], old_dim: Hashable - ) -> Variable: + def _unstack_once_full(self, dims: Mapping[Any, int], old_dim: Hashable) -> Self: """ Unstacks the variable without needing an index. @@ -1824,7 +1820,9 @@ def _unstack_once_full( new_data = reordered.data.reshape(new_shape) new_dims = reordered.dims[: len(other_dims)] + new_dim_names - return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) + return type(self)( + new_dims, new_data, self._attrs, self._encoding, fastpath=True + ) def _unstack_once( self, @@ -1832,7 +1830,7 @@ def _unstack_once( dim: Hashable, fill_value=dtypes.NA, sparse: bool = False, - ) -> Variable: + ) -> Self: """ Unstacks this variable given an index to unstack and the name of the dimension to which the index refers. 
@@ -2044,6 +2042,8 @@ def reduce(
             keep_attrs = _get_keep_attrs(default=False)
         attrs = self._attrs if keep_attrs else None
 
+        # We need to return `Variable` rather than the type of `self` at the moment, ref
+        # #8216
         return Variable(dims, data, attrs=attrs)
 
     @classmethod
@@ -2193,7 +2193,7 @@ def quantile(
         keep_attrs: bool | None = None,
         skipna: bool | None = None,
         interpolation: QuantileMethods | None = None,
-    ) -> Variable:
+    ) -> Self:
         """Compute the qth quantile of the data along the specified dimension.
 
         Returns the qth quantiles(s) of the array elements.
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 66bc69966d2..76dc4345ae7 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -4011,7 +4011,7 @@ def test_dot(self) -> None:
         assert_equal(expected5, actual5)
 
         with pytest.raises(NotImplementedError):
-            da.dot(dm3.to_dataset(name="dm"))  # type: ignore
+            da.dot(dm3.to_dataset(name="dm"))
 
         with pytest.raises(TypeError):
             da.dot(dm3.values)  # type: ignore

From cdf07265f2256ebdc40a69eacae574e77f78fd6b Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Fri, 22 Sep 2023 06:48:33 -0600
Subject: [PATCH 08/46] Allow creating DataArrays with nD coordinate variables
 (#8126)

* Allow creating DataArrays with nD coordinate variables

Closes #2233
Closes #8106

* more tests

* Fix test

* Apply suggestions from code review

Co-authored-by: Michael Niklas

* Update test

---------

Co-authored-by: Michael Niklas
Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
---
 xarray/core/dataarray.py       |  9 +--------
 xarray/tests/test_dataarray.py | 36 +++++++++++++++++++++++++++++++---
 2 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 73464c07c82..724a5fc2580 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -128,7 +128,7 @@ def _check_coords_dims(shape, coords, dims):
                     f"dimensions {dims}"
                 )
 
-            for d, s in zip(v.dims, v.shape):
+            for d, s in v.sizes.items():
                 if s != sizes[d]:
                     raise ValueError(
                         f"conflicting sizes for dimension {d!r}: "
@@ -136,13 +136,6 @@ def _check_coords_dims(shape, coords, dims):
                         f"coordinate {k!r}"
                     )
 
-            if k in sizes and v.shape != (sizes[k],):
-                raise ValueError(
-                    f"coordinate {k!r} is a DataArray dimension, but "
-                    f"it has shape {v.shape!r} rather than expected shape {sizes[k]!r} "
-                    "matching the dimension size"
-                )
-
 
 def _infer_coords_and_dims(
     shape, coords, dims
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 76dc4345ae7..11ebc4da347 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -38,6 +38,7 @@
 from xarray.core.indexes import Index, PandasIndex, filter_indexes_from_coords
 from xarray.core.types import QueryEngineOptions, QueryParserOptions
 from xarray.core.utils import is_scalar
+from xarray.testing import _assert_internal_invariants
 from xarray.tests import (
     InaccessibleArray,
     ReturnItem,
@@ -415,9 +416,6 @@ def test_constructor_invalid(self) -> None:
         with pytest.raises(ValueError, match=r"conflicting MultiIndex"):
             DataArray(np.random.rand(4, 4), [("x", self.mindex), ("level_1", range(4))])
 
-        with pytest.raises(ValueError, match=r"matching the dimension size"):
-            DataArray(data, coords={"x": 0}, dims=["x", "y"])
-
     def test_constructor_from_self_described(self) -> None:
         data = [[-0.1, 21], [0, 2]]
         expected = DataArray(
@@ -7112,3 +7110,35 @@ def test_error_on_ellipsis_without_list(self) -> None:
         da = DataArray([[1, 2],
[1, 2]], dims=("x", "y")) with pytest.raises(ValueError): da.stack(flat=...) # type: ignore + + +def test_nD_coord_dataarray() -> None: + # should succeed + da = DataArray( + np.ones((2, 4)), + dims=("x", "y"), + coords={ + "x": (("x", "y"), np.arange(8).reshape((2, 4))), + "y": ("y", np.arange(4)), + }, + ) + _assert_internal_invariants(da, check_default_indexes=True) + + da2 = DataArray(np.ones(4), dims=("y"), coords={"y": ("y", np.arange(4))}) + da3 = DataArray(np.ones(4), dims=("z")) + + _, actual = xr.align(da, da2) + assert_identical(da2, actual) + + expected = da.drop_vars("x") + _, actual = xr.broadcast(da, da2) + assert_identical(expected, actual) + + actual, _ = xr.broadcast(da, da3) + expected = da.expand_dims(z=4, axis=-1) + assert_identical(actual, expected) + + da4 = DataArray(np.ones((2, 4)), coords={"x": 0}, dims=["x", "y"]) + _assert_internal_invariants(da4, check_default_indexes=True) + assert "x" not in da4.xindexes + assert "x" in da4.coords From 24bf8046d5e8492abc91db78b096644726cf8d6e Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 23 Sep 2023 12:13:13 -0700 Subject: [PATCH 09/46] Remove an import fallback (#8228) --- pyproject.toml | 1 - xarray/__init__.py | 8 ++------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cb51c6ea741..25263928b20 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,6 @@ module = [ "fsspec.*", "h5netcdf.*", "h5py.*", - "importlib_metadata.*", "iris.*", "matplotlib.*", "mpl_toolkits.*", diff --git a/xarray/__init__.py b/xarray/__init__.py index b63b0d81470..1fd3b0c4336 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -1,3 +1,5 @@ +from importlib.metadata import version as _version + from xarray import testing, tutorial from xarray.backends.api import ( load_dataarray, @@ -41,12 +43,6 @@ from xarray.core.variable import IndexVariable, Variable, as_variable from xarray.util.print_versions import show_versions -try: - from importlib.metadata import version as _version -except ImportError: - # if the fallback library is missing, we are doomed. 
- from importlib_metadata import version as _version - try: __version__ = _version("xarray") except Exception: From b14fbd9394a6195680150327d3c10fcb176bbc5f Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 23 Sep 2023 12:38:05 -0700 Subject: [PATCH 10/46] Add a `Literal` typing (#8227) * Add a `Literal` typing --- xarray/core/computation.py | 2 +- xarray/tests/test_computation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 971f036b394..bae779af652 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -893,7 +893,7 @@ def apply_ufunc( dataset_fill_value: object = _NO_FILL_VALUE, keep_attrs: bool | str | None = None, kwargs: Mapping | None = None, - dask: str = "forbidden", + dask: Literal["forbidden", "allowed", "parallelized"] = "forbidden", output_dtypes: Sequence | None = None, output_sizes: Mapping[Any, int] | None = None, meta: Any = None, diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index b75e80db2da..87f8328e441 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1190,7 +1190,7 @@ def test_apply_dask() -> None: # unknown setting for dask array handling with pytest.raises(ValueError): - apply_ufunc(identity, array, dask="unknown") + apply_ufunc(identity, array, dask="unknown") # type: ignore def dask_safe_identity(x): return apply_ufunc(identity, x, dask="allowed") From 77eaa8be439a61ae07939035d07d6890b74d53e8 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 24 Sep 2023 16:03:55 +0200 Subject: [PATCH 11/46] Add typing to functions related to data_vars (#8226) * Update dataset.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more typing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray/core/dataset.py Co-authored-by: Michael Niklas * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michael Niklas --- xarray/core/coordinates.py | 4 ++-- xarray/core/dataset.py | 13 ++++++------- xarray/core/types.py | 3 +++ 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 97ba383ebde..0c85b2a2d69 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -23,7 +23,7 @@ create_default_index_implicit, ) from xarray.core.merge import merge_coordinates_without_align, merge_coords -from xarray.core.types import Self, T_DataArray, T_Xarray +from xarray.core.types import DataVars, Self, T_DataArray, T_Xarray from xarray.core.utils import ( Frozen, ReprObject, @@ -937,7 +937,7 @@ def assert_coordinate_consistent(obj: T_Xarray, coords: Mapping[Any, Variable]) def create_coords_with_default_indexes( - coords: Mapping[Any, Any], data_vars: Mapping[Any, Any] | None = None + coords: Mapping[Any, Any], data_vars: DataVars | None = None ) -> Coordinates: """Returns a Coordinates object from a mapping of coordinates (arbitrary objects). 
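As an aside on the `Literal` annotation added to `apply_ufunc` above: constraining a str parameter to `Literal` values is what turns the invalid `dask="unknown"` call in the test into a static error (hence its new `# type: ignore`). A minimal sketch with a hypothetical function, not xarray API:

    from typing import Literal

    def run(dask: Literal["forbidden", "allowed", "parallelized"] = "forbidden") -> str:
        return dask

    run("parallelized")  # accepted by mypy
    run("unknown")  # runs fine at runtime, but mypy rejects it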
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 9d771f0390c..44016e87306 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -124,7 +124,7 @@ from xarray.backends.api import T_NetcdfEngine, T_NetcdfTypes from xarray.core.dataarray import DataArray from xarray.core.groupby import DatasetGroupBy - from xarray.core.merge import CoercibleMapping, CoercibleValue + from xarray.core.merge import CoercibleMapping, CoercibleValue, _MergeResult from xarray.core.parallelcompat import ChunkManagerEntrypoint from xarray.core.resample import DatasetResample from xarray.core.rolling import DatasetCoarsen, DatasetRolling @@ -133,6 +133,7 @@ CoarsenBoundaryOptions, CombineAttrsOptions, CompatOptions, + DataVars, DatetimeLike, DatetimeUnitOptions, Dims, @@ -404,7 +405,7 @@ def _initialize_feasible(lb, ub): return param_defaults, bounds_defaults -def merge_data_and_coords(data_vars, coords): +def merge_data_and_coords(data_vars: DataVars, coords) -> _MergeResult: """Used in Dataset.__init__.""" if isinstance(coords, Coordinates): coords = coords.copy() @@ -666,7 +667,7 @@ def __init__( self, # could make a VariableArgs to use more generally, and refine these # categories - data_vars: Mapping[Any, Any] | None = None, + data_vars: DataVars | None = None, coords: Mapping[Any, Any] | None = None, attrs: Mapping[Any, Any] | None = None, ) -> None: @@ -1220,9 +1221,7 @@ def _overwrite_indexes( else: return replaced - def copy( - self, deep: bool = False, data: Mapping[Any, ArrayLike] | None = None - ) -> Self: + def copy(self, deep: bool = False, data: DataVars | None = None) -> Self: """Returns a copy of this dataset. If `deep=True`, a deep copy is made of each of the component variables. @@ -1324,7 +1323,7 @@ def copy( def _copy( self, deep: bool = False, - data: Mapping[Any, ArrayLike] | None = None, + data: DataVars | None = None, memo: dict[int, Any] | None = None, ) -> Self: if data is None: diff --git a/xarray/core/types.py b/xarray/core/types.py index e9e700b038e..6b6f9300631 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -187,6 +187,9 @@ def copy( T_Chunks = Union[int, dict[Any, Any], Literal["auto"], None] T_NormalizedChunks = tuple[tuple[int, ...], ...] 
+DataVars = Mapping[Any, Any] + + ErrorOptions = Literal["raise", "ignore"] ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"] From a4f80b23d32e9c3986e3342182fe382d8081c3c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Sun, 24 Sep 2023 17:05:25 +0200 Subject: [PATCH 12/46] override `units` for datetime64/timedelta64 variables to preserve integer dtype (#8201) * remove `dtype` from encoding for datetime64/timedelta64 variables to prevent unnecessary casts * adapt tests * add whats-new.rst entry * Update xarray/coding/times.py Co-authored-by: Spencer Clark * Update doc/whats-new.rst Co-authored-by: Spencer Clark * add test per review suggestion, replace .kind-check with np.issubdtype-check * align timedelta64 check with datetime64 check * override units instead of dtype * remove print statement * warn in case of serialization to floating point, too * align if-else * Add instructions to warnings * Fix test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use warnings.catch_warnings * Update doc/whats-new.rst Co-authored-by: Spencer Clark --------- Co-authored-by: Spencer Clark Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/whats-new.rst | 2 + xarray/coding/times.py | 110 +++++++++++++++++++++--------- xarray/tests/test_coding_times.py | 65 +++++++++++++++--- 3 files changed, 132 insertions(+), 45 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 67429ed7e18..5f18e999cc0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -90,6 +90,8 @@ Bug fixes - ``.rolling_exp`` functions no longer mistakenly lose non-dimensioned coords (:issue:`6528`, :pull:`8114`) By `Maximilian Roos `_. +- In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`). + By `Kai Mühlbauer `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 79efbecfb7c..2822f02dd8d 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -656,8 +656,22 @@ def cast_to_int_if_safe(num) -> np.ndarray: return num +def _division(deltas, delta, floor): + if floor: + # calculate int64 floor division + # to preserve integer dtype if possible (GH 4045, GH7817). + num = deltas // delta.astype(np.int64) + num = num.astype(np.int64, copy=False) + else: + num = deltas / delta + return num + + def encode_cf_datetime( - dates, units: str | None = None, calendar: str | None = None + dates, + units: str | None = None, + calendar: str | None = None, + dtype: np.dtype | None = None, ) -> tuple[np.ndarray, str, str]: """Given an array of datetime objects, returns the tuple `(num, units, calendar)` suitable for a CF compliant time variable. @@ -689,6 +703,12 @@ def encode_cf_datetime( time_units, ref_date = _unpack_time_units_and_ref_date(units) time_delta = _time_units_to_timedelta64(time_units) + # Wrap the dates in a DatetimeIndex to do the subtraction to ensure + # an OverflowError is raised if the ref_date is too far away from + # dates to be encoded (GH 2272). 
+    dates_as_index = pd.DatetimeIndex(dates.ravel())
+    time_deltas = dates_as_index - ref_date
+
     # retrieve needed units to faithfully encode to int64
     needed_units, data_ref_date = _unpack_time_units_and_ref_date(data_units)
     if data_units != units:
@@ -697,26 +717,32 @@ def encode_cf_datetime(
         if ref_delta > np.timedelta64(0, "ns"):
             needed_units = _infer_time_units_from_diff(ref_delta)

-        # Wrap the dates in a DatetimeIndex to do the subtraction to ensure
-        # an OverflowError is raised if the ref_date is too far away from
-        # dates to be encoded (GH 2272).
-        dates_as_index = pd.DatetimeIndex(dates.ravel())
-        time_deltas = dates_as_index - ref_date
-
         # needed time delta to encode faithfully to int64
         needed_time_delta = _time_units_to_timedelta64(needed_units)
-        if time_delta <= needed_time_delta:
-            # calculate int64 floor division
-            # to preserve integer dtype if possible (GH 4045, GH7817).
-            num = time_deltas // time_delta.astype(np.int64)
-            num = num.astype(np.int64, copy=False)
-        else:
-            emit_user_level_warning(
-                f"Times can't be serialized faithfully with requested units {units!r}. "
-                f"Resolution of {needed_units!r} needed. "
-                f"Serializing timeseries to floating point."
-            )
-            num = time_deltas / time_delta
+
+        floor_division = True
+        if time_delta > needed_time_delta:
+            floor_division = False
+            if dtype is None:
+                emit_user_level_warning(
+                    f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
+                    f"Resolution of {needed_units!r} needed. Serializing times to floating point instead. "
+                    f"Set encoding['dtype'] to integer dtype to serialize to int64. "
+                    f"Set encoding['dtype'] to floating point dtype to silence this warning."
+                )
+            elif np.issubdtype(dtype, np.integer):
+                new_units = f"{needed_units} since {format_timestamp(ref_date)}"
+                emit_user_level_warning(
+                    f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
+                    f"Serializing with units {new_units!r} instead. "
+                    f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. "
+                    f"Set encoding['units'] to {new_units!r} to silence this warning."
+                )
+                units = new_units
+                time_delta = needed_time_delta
+                floor_division = True
+
+        num = _division(time_deltas, time_delta, floor_division)
         num = num.values.reshape(dates.shape)

     except (OutOfBoundsDatetime, OverflowError, ValueError):
@@ -728,7 +754,9 @@ def encode_cf_datetime(
     return (num, units, calendar)


-def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarray, str]:
+def encode_cf_timedelta(
+    timedeltas, units: str | None = None, dtype: np.dtype | None = None
+) -> tuple[np.ndarray, str]:
     data_units = infer_timedelta_units(timedeltas)

     if units is None:
@@ -744,18 +772,29 @@ def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarra

     # needed time delta to encode faithfully to int64
     needed_time_delta = _time_units_to_timedelta64(needed_units)
-    if time_delta <= needed_time_delta:
-        # calculate int64 floor division
-        # to preserve integer dtype if possible
-        num = time_deltas // time_delta.astype(np.int64)
-        num = num.astype(np.int64, copy=False)
-    else:
-        emit_user_level_warning(
-            f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
-            f"Resolution of {needed_units!r} needed. "
-            f"Serializing timedeltas to floating point."
-        )
-        num = time_deltas / time_delta
+
+    floor_division = True
+    if time_delta > needed_time_delta:
+        floor_division = False
+        if dtype is None:
+            emit_user_level_warning(
+                f"Timedeltas can't be serialized faithfully to int64 with requested units {units!r}. "
+                f"Resolution of {needed_units!r} needed. Serializing timedeltas to floating point instead. "
+                f"Set encoding['dtype'] to integer dtype to serialize to int64. "
+                f"Set encoding['dtype'] to floating point dtype to silence this warning."
+            )
+        elif np.issubdtype(dtype, np.integer):
+            emit_user_level_warning(
+                f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
+                f"Serializing with units {needed_units!r} instead. "
+                f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. "
+                f"Set encoding['units'] to {needed_units!r} to silence this warning."
+            )
+            units = needed_units
+            time_delta = needed_time_delta
+            floor_division = True
+
+    num = _division(time_deltas, time_delta, floor_division)
     num = num.values.reshape(timedeltas.shape)

     return (num, units)
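Both `encode_cf_datetime` and `encode_cf_timedelta` above funnel into `_division`; the dtype behaviour being protected is plain NumPy arithmetic. A minimal standalone sketch, with invented values:

    import numpy as np

    seconds = np.array([0, 5400, 7200], dtype=np.int64)  # elapsed seconds
    hour = np.int64(3600)

    # Floor division keeps int64, which round-trips losslessly on disk ...
    print((seconds // hour).dtype)  # int64
    # ... but it floors 5400 s down to 1 h, which is why the code first checks
    # that the requested resolution is fine enough, and otherwise overrides units.
    print(seconds // hour)  # [0 1 2]

    # True division preserves the 1.5 h value but promotes to float64,
    # losing exact integer semantics for very large offsets.
    print((seconds / hour).dtype)  # float64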
@@ -772,7 +811,8 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:

         units = encoding.pop("units", None)
         calendar = encoding.pop("calendar", None)
-        (data, units, calendar) = encode_cf_datetime(data, units, calendar)
+        dtype = encoding.get("dtype", None)
+        (data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype)

         safe_setitem(attrs, "units", units, name=name)
         safe_setitem(attrs, "calendar", calendar, name=name)
@@ -807,7 +847,9 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
         if np.issubdtype(variable.data.dtype, np.timedelta64):
             dims, data, attrs, encoding = unpack_for_encoding(variable)

-            data, units = encode_cf_timedelta(data, encoding.pop("units", None))
+            data, units = encode_cf_timedelta(
+                data, encoding.pop("units", None), encoding.get("dtype", None)
+            )

             safe_setitem(attrs, "units", units, name=name)

             return Variable(dims, data, attrs, encoding, fastpath=True)
diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py
index 079e432b565..5f76a4a2ca8 100644
--- a/xarray/tests/test_coding_times.py
+++ b/xarray/tests/test_coding_times.py
@@ -30,7 +30,7 @@
 from xarray.coding.variables import SerializationWarning
 from xarray.conventions import _update_bounds_attributes, cf_encoder
 from xarray.core.common import contains_cftime_datetimes
-from xarray.testing import assert_allclose, assert_equal, assert_identical
+from xarray.testing import assert_equal, assert_identical
 from xarray.tests import (
     FirstElementAccessibleArray,
     arm_xfail,
@@ -1036,7 +1036,7 @@ def test_encode_cf_datetime_defaults_to_correct_dtype(
         pytest.skip("Nanosecond frequency is not valid for cftime dates.")
     times = date_range("2000", periods=3, freq=freq)
     units = f"{encoding_units} since 2000-01-01"
-    encoded, _, _ = coding.times.encode_cf_datetime(times, units)
+    encoded, _units, _ = coding.times.encode_cf_datetime(times, units)

     numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units)
     encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit)
@@ -1212,6 +1212,7 @@ def test_contains_cftime_lazy() -> None:
         ("1677-09-21T00:12:43.145224193", "ns", np.int64, None, False),
         ("1677-09-21T00:12:43.145225", "us", np.int64, None, False),
         ("1970-01-01T00:00:01.000001", "us", np.int64, None, False),
+        ("1677-09-21T00:21:52.901038080", "ns", np.float32, 20.0, True),
     ],
 )
 def test_roundtrip_datetime64_nanosecond_precision(
@@ -1261,14 +1262,52 @@ def 
test_roundtrip_datetime64_nanosecond_precision_warning() -> None: ] units = "days since 1970-01-10T01:01:00" needed_units = "hours" - encoding = dict(_FillValue=20, units=units) + new_units = f"{needed_units} since 1970-01-10T01:01:00" + + encoding = dict(dtype=None, _FillValue=20, units=units) var = Variable(["time"], times, encoding=encoding) - wmsg = ( - f"Times can't be serialized faithfully with requested units {units!r}. " - f"Resolution of {needed_units!r} needed. " - ) - with pytest.warns(UserWarning, match=wmsg): + with pytest.warns(UserWarning, match=f"Resolution of {needed_units!r} needed."): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with pytest.warns( + UserWarning, match=f"Serializing with units {new_units!r} instead." + ): + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="float64", _FillValue=20, units=units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") + encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.float64 + assert encoded_var.attrs["units"] == units + assert encoded_var.attrs["_FillValue"] == 20.0 + + decoded_var = conventions.decode_cf_variable("foo", encoded_var) + assert_identical(var, decoded_var) + + encoding = dict(dtype="int64", _FillValue=20, units=new_units) + var = Variable(["time"], times, encoding=encoding) + with warnings.catch_warnings(): + warnings.simplefilter("error") encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == new_units + assert encoded_var.attrs["_FillValue"] == 20 decoded_var = conventions.decode_cf_variable("foo", encoded_var) assert_identical(var, decoded_var) @@ -1309,14 +1348,18 @@ def test_roundtrip_timedelta64_nanosecond_precision_warning() -> None: needed_units = "hours" wmsg = ( f"Timedeltas can't be serialized faithfully with requested units {units!r}. " - f"Resolution of {needed_units!r} needed. " + f"Serializing with units {needed_units!r} instead." 
) - encoding = dict(_FillValue=20, units=units) + encoding = dict(dtype=np.int64, _FillValue=20, units=units) var = Variable(["time"], timedelta_values, encoding=encoding) with pytest.warns(UserWarning, match=wmsg): encoded_var = conventions.encode_cf_variable(var) + assert encoded_var.dtype == np.int64 + assert encoded_var.attrs["units"] == needed_units + assert encoded_var.attrs["_FillValue"] == 20 decoded_var = conventions.decode_cf_variable("foo", encoded_var) - assert_allclose(var, decoded_var) + assert_identical(var, decoded_var) + assert decoded_var.encoding["dtype"] == np.int64 def test_roundtrip_float_times() -> None: From 05b3a211d5b93acd45be19eab1d0a6e9c72cff8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Sun, 24 Sep 2023 17:28:30 +0200 Subject: [PATCH 13/46] test_interpolate_pd_compat with range of fill_value's (#8189) * ENH: test_interpolate_pd_compat with range of fill_value's * add whats-new.rst entry --- doc/whats-new.rst | 3 +++ xarray/tests/test_missing.py | 28 +++++++++++++++++++--------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 5f18e999cc0..9a21bcb7ab9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -108,6 +108,9 @@ Internal Changes than `.reduce`, as the start of a broader effort to move non-reducing functions away from ```.reduce``, (:pull:`8114`). By `Maximilian Roos `_. +- Test range of fill_value's in test_interpolate_pd_compat (:issue:`8146`, :pull:`8189`). + By `Kai Mühlbauer `_. + .. _whats-new.2023.08.0: diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index fe2cdc58807..c57d84c927d 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -92,26 +92,36 @@ def make_interpolate_example_data(shape, frac_nan, seed=12345, non_uniform=False return da, df +@pytest.mark.parametrize("fill_value", [None, np.nan, 47.11]) +@pytest.mark.parametrize( + "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] +) @requires_scipy -def test_interpolate_pd_compat(): +def test_interpolate_pd_compat(method, fill_value) -> None: shapes = [(8, 8), (1, 20), (20, 1), (100, 100)] frac_nans = [0, 0.5, 1] - methods = ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"] - for shape, frac_nan, method in itertools.product(shapes, frac_nans, methods): + for shape, frac_nan in itertools.product(shapes, frac_nans): da, df = make_interpolate_example_data(shape, frac_nan) for dim in ["time", "x"]: - actual = da.interpolate_na(method=method, dim=dim, fill_value=np.nan) + actual = da.interpolate_na(method=method, dim=dim, fill_value=fill_value) + # need limit_direction="both" here, to let pandas fill + # in both directions instead of default forward direction only expected = df.interpolate( method=method, axis=da.get_axis_num(dim), + limit_direction="both", + fill_value=fill_value, ) - # Note, Pandas does some odd things with the left/right fill_value - # for the linear methods. This next line inforces the xarray - # fill_value convention on the pandas output. Therefore, this test - # only checks that interpolated values are the same (not nans) - expected.values[pd.isnull(actual.values)] = np.nan + + if method == "linear": + # Note, Pandas does not take left/right fill_value into account + # for the numpy linear methods. 
+ # see https://github.com/pandas-dev/pandas/issues/55144 + # This aligns the pandas output with the xarray output + expected.values[pd.isnull(actual.values)] = np.nan + expected.values[actual.values == fill_value] = fill_value np.testing.assert_allclose(actual.values, expected.values) From 565b23b95beda893e0d66d1e2c6da49984bb0925 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 25 Sep 2023 06:43:54 +0200 Subject: [PATCH 14/46] Rewrite typed_ops (#8204) * rewrite typed_ops * improved typing of rolling instance attrs * add typed_ops xr.Variable tests * add typed_ops test * add minor typehint * adjust to numpy 1.24 * add groupby ops type tests * remove wrong types from ops * fix Dataset not being part of SupportsArray Protocol * ignore mypy align complaint * add reasons for type ignores in test * add overloads for variable typed ops * move tests to their own module * add entry to whats-new --- doc/whats-new.rst | 3 + xarray/core/_typed_ops.py | 591 ++++++++++++++++--------- xarray/core/_typed_ops.pyi | 782 --------------------------------- xarray/core/dataarray.py | 15 +- xarray/core/dataset.py | 21 +- xarray/core/rolling.py | 28 +- xarray/core/types.py | 7 +- xarray/core/weighted.py | 1 + xarray/tests/test_groupby.py | 4 +- xarray/tests/test_typed_ops.py | 246 +++++++++++ xarray/util/generate_ops.py | 286 ++++++------ 11 files changed, 827 insertions(+), 1157 deletions(-) delete mode 100644 xarray/core/_typed_ops.pyi create mode 100644 xarray/tests/test_typed_ops.py diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9a21bcb7ab9..4307c2829ca 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -92,6 +92,9 @@ Bug fixes By `Maximilian Roos `_. - In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`). By `Kai Mühlbauer `_. +- Static typing of dunder ops methods (like :py:meth:`DataArray.__eq__`) has been fixed. + Remaining issues are upstream problems (:issue:`7780`, :pull:`8204`). + By `Michael Niklas `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/_typed_ops.py b/xarray/core/_typed_ops.py index d3a783be45d..330d13bb217 100644 --- a/xarray/core/_typed_ops.py +++ b/xarray/core/_typed_ops.py @@ -1,165 +1,182 @@ """Mixin classes with arithmetic operators.""" # This file was generated using xarray.util.generate_ops. Do not edit manually. 
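Before the regenerated mixins below, the shape they all share, reduced to a single operator: every dunder delegates to a `_binary_op` hook that the concrete classes implement. A standalone sketch with an invented `Num` class; the real hook lives on Dataset, DataArray, and Variable:

    import operator
    from typing import Any, Callable

    class OpsMixin:
        def _binary_op(self, other: Any, f: Callable, reflexive: bool = False) -> "OpsMixin":
            raise NotImplementedError

        def __add__(self, other: Any) -> "OpsMixin":
            return self._binary_op(other, operator.add)

        def __radd__(self, other: Any) -> "OpsMixin":
            # reflexive=True flips the operand order for `other + self`
            return self._binary_op(other, operator.add, reflexive=True)

    class Num(OpsMixin):
        def __init__(self, v: float) -> None:
            self.v = v

        def _binary_op(self, other: Any, f: Callable, reflexive: bool = False) -> "Num":
            o = other.v if isinstance(other, Num) else other
            return Num(f(o, self.v) if reflexive else f(self.v, o))

    print((Num(2) + 3).v)  # 5 -- dispatched through _binary_op
    print((3 + Num(2)).v)  # 5 -- via __radd__, reflexive=True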
+from __future__ import annotations + import operator +from typing import TYPE_CHECKING, Any, Callable, NoReturn, overload from xarray.core import nputils, ops +from xarray.core.types import ( + DaCompatible, + DsCompatible, + GroupByCompatible, + Self, + T_DataArray, + T_Xarray, + VarCompatible, +) + +if TYPE_CHECKING: + from xarray.core.dataset import Dataset class DatasetOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: DsCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: DsCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: DsCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.truediv, 
reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: DsCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: DsCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: DsCompatible) -> Self: return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return 
self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -215,157 +232,159 @@ def conjugate(self, *args, **kwargs): class DataArrayOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: DaCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: DaCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: DaCompatible) -> Self: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def 
__rmod__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: DaCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: DaCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: DaCompatible) -> Self: return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -421,157 +440,303 @@ def conjugate(self, *args, **kwargs): class VariableOpsMixin: __slots__ = () - def _binary_op(self, other, f, 
reflexive=False): + def _binary_op( + self, other: VarCompatible, f: Callable, reflexive: bool = False + ) -> Self: raise NotImplementedError - def __add__(self, other): + @overload + def __add__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __add__(self, other: VarCompatible) -> Self: + ... + + def __add__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.add) - def __sub__(self, other): + @overload + def __sub__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __sub__(self, other: VarCompatible) -> Self: + ... + + def __sub__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.sub) - def __mul__(self, other): + @overload + def __mul__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __mul__(self, other: VarCompatible) -> Self: + ... + + def __mul__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mul) - def __pow__(self, other): + @overload + def __pow__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __pow__(self, other: VarCompatible) -> Self: + ... + + def __pow__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + @overload + def __truediv__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __truediv__(self, other: VarCompatible) -> Self: + ... + + def __truediv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + @overload + def __floordiv__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __floordiv__(self, other: VarCompatible) -> Self: + ... + + def __floordiv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + @overload + def __mod__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __mod__(self, other: VarCompatible) -> Self: + ... + + def __mod__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mod) - def __and__(self, other): + @overload + def __and__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __and__(self, other: VarCompatible) -> Self: + ... + + def __and__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.and_) - def __xor__(self, other): + @overload + def __xor__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __xor__(self, other: VarCompatible) -> Self: + ... + + def __xor__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.xor) - def __or__(self, other): + @overload + def __or__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __or__(self, other: VarCompatible) -> Self: + ... + + def __or__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + @overload + def __lshift__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __lshift__(self, other: VarCompatible) -> Self: + ... + + def __lshift__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + @overload + def __rshift__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __rshift__(self, other: VarCompatible) -> Self: + ... + + def __rshift__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + @overload + def __lt__(self, other: T_DataArray) -> NoReturn: + ... 
+ + @overload + def __lt__(self, other: VarCompatible) -> Self: + ... + + def __lt__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.lt) - def __le__(self, other): + @overload + def __le__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __le__(self, other: VarCompatible) -> Self: + ... + + def __le__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.le) - def __gt__(self, other): + @overload + def __gt__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __gt__(self, other: VarCompatible) -> Self: + ... + + def __gt__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.gt) - def __ge__(self, other): + @overload + def __ge__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __ge__(self, other: VarCompatible) -> Self: + ... + + def __ge__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.ge) - def __eq__(self, other): + @overload # type:ignore[override] + def __eq__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __eq__(self, other: VarCompatible) -> Self: + ... + + def __eq__(self, other: VarCompatible) -> Self: return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + @overload # type:ignore[override] + def __ne__(self, other: T_DataArray) -> NoReturn: + ... + + @overload + def __ne__(self, other: VarCompatible) -> Self: + ... + + def __ne__(self, other: VarCompatible) -> Self: return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: VarCompatible) -> Self: return self._binary_op(other, operator.or_, reflexive=True) - def _inplace_binary_op(self, other, f): + def _inplace_binary_op(self, other: VarCompatible, f: Callable) -> Self: raise NotImplementedError - def __iadd__(self, other): + def __iadd__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.iadd) - def __isub__(self, other): + def __isub__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.isub) - def __imul__(self, other): + def __imul__(self, other: VarCompatible) -> Self: # type:ignore[misc] return 
self._inplace_binary_op(other, operator.imul) - def __ipow__(self, other): + def __ipow__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ipow) - def __itruediv__(self, other): + def __itruediv__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.itruediv) - def __ifloordiv__(self, other): + def __ifloordiv__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ifloordiv) - def __imod__(self, other): + def __imod__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.imod) - def __iand__(self, other): + def __iand__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.iand) - def __ixor__(self, other): + def __ixor__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ixor) - def __ior__(self, other): + def __ior__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ior) - def __ilshift__(self, other): + def __ilshift__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.ilshift) - def __irshift__(self, other): + def __irshift__(self, other: VarCompatible) -> Self: # type:ignore[misc] return self._inplace_binary_op(other, operator.irshift) - def _unary_op(self, f, *args, **kwargs): + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: raise NotImplementedError - def __neg__(self): + def __neg__(self) -> Self: return self._unary_op(operator.neg) - def __pos__(self): + def __pos__(self) -> Self: return self._unary_op(operator.pos) - def __abs__(self): + def __abs__(self) -> Self: return self._unary_op(operator.abs) - def __invert__(self): + def __invert__(self) -> Self: return self._unary_op(operator.invert) - def round(self, *args, **kwargs): + def round(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.round_, *args, **kwargs) - def argsort(self, *args, **kwargs): + def argsort(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.argsort, *args, **kwargs) - def conj(self, *args, **kwargs): + def conj(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conj, *args, **kwargs) - def conjugate(self, *args, **kwargs): + def conjugate(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op(ops.conjugate, *args, **kwargs) __add__.__doc__ = operator.add.__doc__ @@ -627,91 +792,93 @@ def conjugate(self, *args, **kwargs): class DatasetGroupByOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: GroupByCompatible, f: Callable, reflexive: bool = False + ) -> Dataset: raise NotImplementedError - def __add__(self, other): + def __add__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.truediv) - def 
__floordiv__(self, other): + def __floordiv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: GroupByCompatible) -> Dataset: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ @@ -747,91 +914,93 @@ def __ror__(self, other): class DataArrayGroupByOpsMixin: __slots__ = () - def _binary_op(self, other, f, reflexive=False): + def _binary_op( + self, other: T_Xarray, f: Callable, reflexive: bool = False + ) -> T_Xarray: raise 
NotImplementedError - def __add__(self, other): + def __add__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mul) - def __pow__(self, other): + def __pow__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.pow) - def __truediv__(self, other): + def __truediv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other): + def __floordiv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.floordiv) - def __mod__(self, other): + def __mod__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mod) - def __and__(self, other): + def __and__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.and_) - def __xor__(self, other): + def __xor__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.xor) - def __or__(self, other): + def __or__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.or_) - def __lshift__(self, other): + def __lshift__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.lshift) - def __rshift__(self, other): + def __rshift__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.rshift) - def __lt__(self, other): + def __lt__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.lt) - def __le__(self, other): + def __le__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.ge) - def __eq__(self, other): + def __eq__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other): + def __ne__(self, other: T_Xarray) -> T_Xarray: # type:ignore[override] return self._binary_op(other, nputils.array_ne) - def __radd__(self, other): + def __radd__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other): + def __rsub__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other): + def __rmul__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other): + def __rpow__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other): + def __rtruediv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other): + def __rmod__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other): + def __rand__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other): + def __rxor__(self, other: T_Xarray) -> T_Xarray: return 
self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other): + def __ror__(self, other: T_Xarray) -> T_Xarray: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi deleted file mode 100644 index 9e2ba2d3a06..00000000000 --- a/xarray/core/_typed_ops.pyi +++ /dev/null @@ -1,782 +0,0 @@ -"""Stub file for mixin classes with arithmetic operators.""" -# This file was generated using xarray.util.generate_ops. Do not edit manually. - -from typing import NoReturn, TypeVar, overload - -import numpy as np -from numpy.typing import ArrayLike - -from .dataarray import DataArray -from .dataset import Dataset -from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy -from .types import ( - DaCompatible, - DsCompatible, - GroupByIncompatible, - ScalarOrArray, - VarCompatible, -) -from .variable import Variable - -try: - from dask.array import Array as DaskArray -except ImportError: - DaskArray = np.ndarray # type: ignore - -# DatasetOpsMixin etc. are parent classes of Dataset etc. -# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally -# we use the ones in `types`. (We're open to refining this, and potentially integrating -# the `py` & `pyi` files to simplify them.) -T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin") - -class DatasetOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - def __add__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __sub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __mul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __pow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __truediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __floordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __mod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __and__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __xor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __or__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __lshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rshift__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __lt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __le__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __gt__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __ge__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __eq__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] - def __ne__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... # type: ignore[override] - def __radd__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rsub__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rmul__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rpow__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rtruediv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rfloordiv__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rmod__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __rand__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... 
- def __rxor__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def __ror__(self: T_Dataset, other: DsCompatible) -> T_Dataset: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_Dataset) -> T_Dataset: ... - def __pos__(self: T_Dataset) -> T_Dataset: ... - def __abs__(self: T_Dataset) -> T_Dataset: ... - def __invert__(self: T_Dataset) -> T_Dataset: ... - def round(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def argsort(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def conj(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - def conjugate(self: T_Dataset, *args, **kwargs) -> T_Dataset: ... - -class DataArrayOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __add__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __sub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __mul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __pow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __truediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __floordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __mod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __and__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __xor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __or__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __lshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: "DatasetGroupBy") -> "Dataset": ... 
- @overload - def __rshift__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __lt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __le__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __gt__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ge__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __eq__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ne__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __radd__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rsub__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rmul__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rpow__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rtruediv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rfloordiv__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rmod__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __rand__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: "DatasetGroupBy") -> "Dataset": ... 
- @overload - def __rxor__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def __ror__(self: T_DataArray, other: DaCompatible) -> T_DataArray: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_DataArray) -> T_DataArray: ... - def __pos__(self: T_DataArray) -> T_DataArray: ... - def __abs__(self: T_DataArray) -> T_DataArray: ... - def __invert__(self: T_DataArray) -> T_DataArray: ... - def round(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def argsort(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def conj(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - def conjugate(self: T_DataArray, *args, **kwargs) -> T_DataArray: ... - -class VariableOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __add__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __sub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __pow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __truediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __floordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __and__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __xor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __or__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ... 
- @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rshift__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __le__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __gt__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ge__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __eq__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ne__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __radd__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rsub__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmul__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rpow__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rtruediv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rfloordiv__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmod__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rand__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... 
- @overload - def __rxor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rxor__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ror__(self: T_Variable, other: VarCompatible) -> T_Variable: ... - def _inplace_binary_op(self, other, f): ... - def _unary_op(self, f, *args, **kwargs): ... - def __neg__(self: T_Variable) -> T_Variable: ... - def __pos__(self: T_Variable) -> T_Variable: ... - def __abs__(self: T_Variable) -> T_Variable: ... - def __invert__(self: T_Variable) -> T_Variable: ... - def round(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def argsort(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def conj(self: T_Variable, *args, **kwargs) -> T_Variable: ... - def conjugate(self: T_Variable, *args, **kwargs) -> T_Variable: ... - -class DatasetGroupByOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: "DataArray") -> "Dataset": ... - @overload - def __add__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: "DataArray") -> "Dataset": ... - @overload - def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: "DataArray") -> "Dataset": ... - @overload - def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: "DataArray") -> "Dataset": ... - @overload - def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: "DataArray") -> "Dataset": ... - @overload - def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: "DataArray") -> "Dataset": ... - @overload - def __and__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: "DataArray") -> "Dataset": ... - @overload - def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: "DataArray") -> "Dataset": ... - @overload - def __or__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: "DataArray") -> "Dataset": ... - @overload - def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... 
- @overload - def __rshift__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: "DataArray") -> "Dataset": ... - @overload - def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: "DataArray") -> "Dataset": ... - @overload - def __le__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: "DataArray") -> "Dataset": ... - @overload - def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: "DataArray") -> "Dataset": ... - @overload - def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: "DataArray") -> "Dataset": ... - @overload - def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: "DataArray") -> "Dataset": ... - @overload - def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... 
- @overload - def __ror__(self, other: "DataArray") -> "Dataset": ... - @overload - def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... - -class DataArrayGroupByOpsMixin: - __slots__ = () - def _binary_op(self, other, f, reflexive=...): ... - @overload - def __add__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __add__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __sub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __sub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __pow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __pow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __truediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __truediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __floordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __floordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __mod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __mod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __and__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __and__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __xor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __xor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __or__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __or__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rshift__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rshift__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rshift__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __lt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __lt__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __le__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __le__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __gt__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __gt__(self, other: GroupByIncompatible) -> NoReturn: ... 
- @overload - def __ge__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ge__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __eq__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __eq__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __eq__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload # type: ignore[override] - def __ne__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ne__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ne__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __radd__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __radd__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __radd__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rsub__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rsub__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rsub__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmul__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmul__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmul__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rpow__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rpow__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rpow__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rtruediv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rtruediv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rtruediv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rfloordiv__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rfloordiv__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rfloordiv__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rmod__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rmod__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rmod__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rand__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rand__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rand__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __rxor__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __rxor__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __rxor__(self, other: GroupByIncompatible) -> NoReturn: ... - @overload - def __ror__(self, other: T_Dataset) -> T_Dataset: ... - @overload - def __ror__(self, other: T_DataArray) -> T_DataArray: ... - @overload - def __ror__(self, other: GroupByIncompatible) -> NoReturn: ... 
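
For reference, a minimal sketch of what the generated overloads in `_typed_ops.py` now let a type checker infer, replacing the information that lived in this deleted stub (hypothetical snippet, not part of the diff; `reveal_type` is a mypy-only construct):

    import numpy as np
    from xarray import DataArray, Variable

    var = Variable(dims=["t"], data=np.arange(3))
    da = DataArray(np.arange(3), dims=["t"])

    # Variable.__add__ gains an overload `(self, other: T_DataArray) -> NoReturn`,
    # which steers mypy to DataArray.__radd__, so the "richer" type wins:
    # reveal_type(var + da)            # DataArray
    # reveal_type(da + var)            # DataArray (Variable is DaCompatible)
    # reveal_type(var + np.arange(3))  # Variable (the VarCompatible overload)
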
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 724a5fc2580..0b9786dc2b7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4618,25 +4618,22 @@ def _unary_op(self, f: Callable, *args, **kwargs) -> Self: return da def _binary_op( - self, - other: T_Xarray, - f: Callable, - reflexive: bool = False, - ) -> T_Xarray: + self, other: DaCompatible, f: Callable, reflexive: bool = False + ) -> Self: from xarray.core.groupby import GroupBy if isinstance(other, (Dataset, GroupBy)): return NotImplemented if isinstance(other, DataArray): align_type = OPTIONS["arithmetic_join"] - self, other = align(self, other, join=align_type, copy=False) - other_variable = getattr(other, "variable", other) + self, other = align(self, other, join=align_type, copy=False) # type: ignore[type-var,assignment] + other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other) other_coords = getattr(other, "coords", None) variable = ( - f(self.variable, other_variable) + f(self.variable, other_variable_or_arraylike) if not reflexive - else f(other_variable, self.variable) + else f(other_variable_or_arraylike, self.variable) ) coords, indexes = self.coords._merge_raw(other_coords, reflexive) name = self._result_name(other) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 44016e87306..d24a62414ea 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1475,13 +1475,20 @@ def __bool__(self) -> bool: def __iter__(self) -> Iterator[Hashable]: return iter(self.data_vars) - def __array__(self, dtype=None): - raise TypeError( - "cannot directly convert an xarray.Dataset into a " - "numpy array. Instead, create an xarray.DataArray " - "first, either with indexing on the Dataset or by " - "invoking the `to_array()` method." - ) + if TYPE_CHECKING: + # needed because __getattr__ is returning Any and otherwise + # this class counts as part of the SupportsArray Protocol + __array__ = None + + else: + + def __array__(self, dtype=None): + raise TypeError( + "cannot directly convert an xarray.Dataset into a " + "numpy array. Instead, create an xarray.DataArray " + "first, either with indexing on the Dataset or by " + "invoking the `to_array()` method." 
+ ) @property def nbytes(self) -> int: diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index c6911cbe65b..b85092982e3 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -61,6 +61,11 @@ class Rolling(Generic[T_Xarray]): __slots__ = ("obj", "window", "min_periods", "center", "dim") _attributes = ("window", "min_periods", "center", "dim") + dim: list[Hashable] + window: list[int] + center: list[bool] + obj: T_Xarray + min_periods: int def __init__( self, @@ -91,8 +96,8 @@ def __init__( ------- rolling : type of input argument """ - self.dim: list[Hashable] = [] - self.window: list[int] = [] + self.dim = [] + self.window = [] for d, w in windows.items(): self.dim.append(d) if w <= 0: @@ -100,7 +105,7 @@ def __init__( self.window.append(w) self.center = self._mapping_to_list(center, default=False) - self.obj: T_Xarray = obj + self.obj = obj missing_dims = tuple(dim for dim in self.dim if dim not in self.obj.dims) if missing_dims: @@ -814,6 +819,10 @@ class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): ) _attributes = ("windows", "side", "trim_excess") obj: T_Xarray + windows: Mapping[Hashable, int] + side: SideOptions | Mapping[Hashable, SideOptions] + boundary: CoarsenBoundaryOptions + coord_func: Mapping[Hashable, str | Callable] def __init__( self, @@ -855,12 +864,15 @@ def __init__( f"Window dimensions {missing_dims} not found in {self.obj.__class__.__name__} " f"dimensions {tuple(self.obj.dims)}" ) - if not utils.is_dict_like(coord_func): - coord_func = {d: coord_func for d in self.obj.dims} # type: ignore[misc] + + if utils.is_dict_like(coord_func): + coord_func_map = coord_func + else: + coord_func_map = {d: coord_func for d in self.obj.dims} for c in self.obj.coords: - if c not in coord_func: - coord_func[c] = duck_array_ops.mean # type: ignore[index] - self.coord_func: Mapping[Hashable, str | Callable] = coord_func + if c not in coord_func_map: + coord_func_map[c] = duck_array_ops.mean # type: ignore[index] + self.coord_func = coord_func_map def _get_keep_attrs(self, keep_attrs): if keep_attrs is None: diff --git a/xarray/core/types.py b/xarray/core/types.py index 6b6f9300631..073121b13b1 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -38,7 +38,6 @@ from xarray.core.coordinates import Coordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.groupby import DataArrayGroupBy, GroupBy from xarray.core.indexes import Index, Indexes from xarray.core.utils import Frozen from xarray.core.variable import Variable @@ -176,10 +175,10 @@ def copy( T_DuckArray = TypeVar("T_DuckArray", bound=Any) ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] -DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"] -DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"] VarCompatible = Union["Variable", "ScalarOrArray"] -GroupByIncompatible = Union["Variable", "GroupBy"] +DaCompatible = Union["DataArray", "VarCompatible"] +DsCompatible = Union["Dataset", "DaCompatible"] +GroupByCompatible = Union["Dataset", "DataArray"] Dims = Union[str, Iterable[Hashable], "ellipsis", None] OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None] diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 82ffe684ec7..b1ea1ee625c 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -324,6 +324,7 @@ def _weighted_quantile( def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray: 
"""Return the interpolation parameter.""" # Note that options are not yet exposed in the public API. + h: np.ndarray if method == "linear": h = (n - 1) * q + 1 elif method == "interpolated_inverted_cdf": diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e143e2b8e03..320ba999318 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -810,9 +810,9 @@ def test_groupby_math_more() -> None: with pytest.raises(TypeError, match=r"only support binary ops"): grouped + 1 # type: ignore[operator] with pytest.raises(TypeError, match=r"only support binary ops"): - grouped + grouped + grouped + grouped # type: ignore[operator] with pytest.raises(TypeError, match=r"in-place operations"): - ds += grouped + ds += grouped # type: ignore[arg-type] ds = Dataset( { diff --git a/xarray/tests/test_typed_ops.py b/xarray/tests/test_typed_ops.py new file mode 100644 index 00000000000..1d4ef89ae29 --- /dev/null +++ b/xarray/tests/test_typed_ops.py @@ -0,0 +1,246 @@ +import numpy as np + +from xarray import DataArray, Dataset, Variable + + +def test_variable_typed_ops() -> None: + """Tests for type checking of typed_ops on Variable""" + + var = Variable(dims=["t"], data=[1, 2, 3]) + + def _test(var: Variable) -> None: + # mypy checks the input type + assert isinstance(var, Variable) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + + # __add__ as an example of binary ops + _test(var + _int) + _test(var + _list) + _test(var + _ndarray) + _test(var + var) + + # __radd__ as an example of reflexive binary ops + _test(_int + var) + _test(_list + var) + _test(_ndarray + var) # type: ignore[arg-type] # numpy problem + + # __eq__ as an example of cmp ops + _test(var == _int) + _test(var == _list) + _test(var == _ndarray) + _test(_int == var) # type: ignore[arg-type] # typeshed problem + _test(_list == var) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == var) + + # __lt__ as another example of cmp ops + _test(var < _int) + _test(var < _list) + _test(var < _ndarray) + _test(_int > var) + _test(_list > var) + _test(_ndarray > var) # type: ignore[arg-type] # numpy problem + + # __iadd__ as an example of inplace binary ops + var += _int + var += _list + var += _ndarray + + # __neg__ as an example of unary ops + _test(-var) + + +def test_dataarray_typed_ops() -> None: + """Tests for type checking of typed_ops on DataArray""" + + da = DataArray([1, 2, 3], dims=["t"]) + + def _test(da: DataArray) -> None: + # mypy checks the input type + assert isinstance(da, DataArray) + + _int: int = 1 + _list = [1, 2, 3] + _ndarray = np.array([1, 2, 3]) + _var = Variable(dims=["t"], data=[1, 2, 3]) + + # __add__ as an example of binary ops + _test(da + _int) + _test(da + _list) + _test(da + _ndarray) + _test(da + _var) + _test(da + da) + + # __radd__ as an example of reflexive binary ops + _test(_int + da) + _test(_list + da) + _test(_ndarray + da) # type: ignore[arg-type] # numpy problem + _test(_var + da) + + # __eq__ as an example of cmp ops + _test(da == _int) + _test(da == _list) + _test(da == _ndarray) + _test(da == _var) + _test(_int == da) # type: ignore[arg-type] # typeshed problem + _test(_list == da) # type: ignore[arg-type] # typeshed problem + _test(_ndarray == da) + _test(_var == da) + + # __lt__ as another example of cmp ops + _test(da < _int) + _test(da < _list) + _test(da < _ndarray) + _test(da < _var) + _test(_int > da) + _test(_list > da) + _test(_ndarray > da) # type: ignore[arg-type] # numpy problem + _test(_var > da) + + # __iadd__ 
as an example of inplace binary ops
+    da += _int
+    da += _list
+    da += _ndarray
+    da += _var
+
+    # __neg__ as an example of unary ops
+    _test(-da)
+
+
+def test_dataset_typed_ops() -> None:
+    """Tests for type checking of typed_ops on Dataset"""
+
+    ds = Dataset({"a": ("t", [1, 2, 3])})
+
+    def _test(ds: Dataset) -> None:
+        # mypy checks the input type
+        assert isinstance(ds, Dataset)
+
+    _int: int = 1
+    _list = [1, 2, 3]
+    _ndarray = np.array([1, 2, 3])
+    _var = Variable(dims=["t"], data=[1, 2, 3])
+    _da = DataArray([1, 2, 3], dims=["t"])
+
+    # __add__ as an example of binary ops
+    _test(ds + _int)
+    _test(ds + _list)
+    _test(ds + _ndarray)
+    _test(ds + _var)
+    _test(ds + _da)
+    _test(ds + ds)
+
+    # __radd__ as an example of reflexive binary ops
+    _test(_int + ds)
+    _test(_list + ds)
+    _test(_ndarray + ds)
+    _test(_var + ds)
+    _test(_da + ds)
+
+    # __eq__ as an example of cmp ops
+    _test(ds == _int)
+    _test(ds == _list)
+    _test(ds == _ndarray)
+    _test(ds == _var)
+    _test(ds == _da)
+    _test(_int == ds)  # type: ignore[arg-type] # typeshed problem
+    _test(_list == ds)  # type: ignore[arg-type] # typeshed problem
+    _test(_ndarray == ds)
+    _test(_var == ds)
+    _test(_da == ds)
+
+    # __lt__ as another example of cmp ops
+    _test(ds < _int)
+    _test(ds < _list)
+    _test(ds < _ndarray)
+    _test(ds < _var)
+    _test(ds < _da)
+    _test(_int > ds)
+    _test(_list > ds)
+    _test(_ndarray > ds)  # type: ignore[arg-type] # numpy problem
+    _test(_var > ds)
+    _test(_da > ds)
+
+    # __iadd__ as an example of inplace binary ops
+    ds += _int
+    ds += _list
+    ds += _ndarray
+    ds += _var
+    ds += _da
+
+    # __neg__ as an example of unary ops
+    _test(-ds)
+
+
+def test_dataarray_groupby_typed_ops() -> None:
+    """Tests for type checking of typed_ops on DataArrayGroupBy"""
+
+    da = DataArray([1, 2, 3], coords={"x": ("t", [1, 2, 2])}, dims=["t"])
+    grp = da.groupby("x")
+
+    def _testda(da: DataArray) -> None:
+        # mypy checks the input type
+        assert isinstance(da, DataArray)
+
+    def _testds(ds: Dataset) -> None:
+        # mypy checks the input type
+        assert isinstance(ds, Dataset)
+
+    _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x")
+    _ds = _da.to_dataset(name="a")
+
+    # __add__ as an example of binary ops
+    _testda(grp + _da)
+    _testds(grp + _ds)
+
+    # __radd__ as an example of reflexive binary ops
+    _testda(_da + grp)
+    _testds(_ds + grp)
+
+    # __eq__ as an example of cmp ops
+    _testda(grp == _da)
+    _testda(_da == grp)
+    _testds(grp == _ds)
+    _testds(_ds == grp)
+
+    # __lt__ as another example of cmp ops
+    _testda(grp < _da)
+    _testda(_da > grp)
+    _testds(grp < _ds)
+    _testds(_ds > grp)
+
+
+def test_dataset_groupby_typed_ops() -> None:
+    """Tests for type checking of typed_ops on DatasetGroupBy"""
+
+    ds = Dataset({"a": ("t", [1, 2, 3])}, coords={"x": ("t", [1, 2, 2])})
+    grp = ds.groupby("x")
+
+    def _test(ds: Dataset) -> None:
+        # mypy checks the input type
+        assert isinstance(ds, Dataset)
+
+    _da = DataArray([5, 6], coords={"x": [1, 2]}, dims="x")
+    _ds = _da.to_dataset(name="a")
+
+    # __add__ as an example of binary ops
+    _test(grp + _da)
+    _test(grp + _ds)
+
+    # __radd__ as an example of reflexive binary ops
+    _test(_da + grp)
+    _test(_ds + grp)
+
+    # __eq__ as an example of cmp ops
+    _test(grp == _da)
+    _test(_da == grp)
+    _test(grp == _ds)
+    _test(_ds == grp)
+
+    # __lt__ as another example of cmp ops
+    _test(grp < _da)
+    _test(_da > grp)
+    _test(grp < _ds)
+    _test(_ds > grp)
diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py
index cf0673e7cca..632ca06d295 100644
--- a/xarray/util/generate_ops.py
+++ 
b/xarray/util/generate_ops.py @@ -3,14 +3,16 @@ For internal xarray development use only. Usage: - python xarray/util/generate_ops.py --module > xarray/core/_typed_ops.py - python xarray/util/generate_ops.py --stubs > xarray/core/_typed_ops.pyi + python xarray/util/generate_ops.py > xarray/core/_typed_ops.py """ # Note: the comments in https://github.com/pydata/xarray/pull/4904 provide some # background to some of the design choices made here. -import sys +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from typing import Optional BINOPS_EQNE = (("__eq__", "nputils.array_eq"), ("__ne__", "nputils.array_ne")) BINOPS_CMP = ( @@ -74,155 +76,178 @@ ("conjugate", "ops.conjugate"), ) + +required_method_binary = """ + def _binary_op( + self, other: {other_type}, f: Callable, reflexive: bool = False + ) -> {return_type}: + raise NotImplementedError""" template_binop = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> {return_type}:{type_ignore} return self._binary_op(other, {func})""" +template_binop_overload = """ + @overload{overload_type_ignore} + def {method}(self, other: {overload_type}) -> NoReturn: + ... + + @overload + def {method}(self, other: {other_type}) -> {return_type}: + ... +""" template_reflexive = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> {return_type}: return self._binary_op(other, {func}, reflexive=True)""" + +required_method_inplace = """ + def _inplace_binary_op(self, other: {other_type}, f: Callable) -> Self: + raise NotImplementedError""" template_inplace = """ - def {method}(self, other): + def {method}(self, other: {other_type}) -> Self:{type_ignore} return self._inplace_binary_op(other, {func})""" + +required_method_unary = """ + def _unary_op(self, f: Callable, *args: Any, **kwargs: Any) -> Self: + raise NotImplementedError""" template_unary = """ - def {method}(self): + def {method}(self) -> Self: return self._unary_op({func})""" template_other_unary = """ - def {method}(self, *args, **kwargs): + def {method}(self, *args: Any, **kwargs: Any) -> Self: return self._unary_op({func}, *args, **kwargs)""" -required_method_unary = """ - def _unary_op(self, f, *args, **kwargs): - raise NotImplementedError""" -required_method_binary = """ - def _binary_op(self, other, f, reflexive=False): - raise NotImplementedError""" -required_method_inplace = """ - def _inplace_binary_op(self, other, f): - raise NotImplementedError""" # For some methods we override return type `bool` defined by base class `object`. -OVERRIDE_TYPESHED = {"override": " # type: ignore[override]"} -NO_OVERRIDE = {"override": ""} - -# Note: in some of the overloads below the return value in reality is NotImplemented, -# which cannot accurately be expressed with type hints,e.g. Literal[NotImplemented] -# or type(NotImplemented) are not allowed and NoReturn has a different meaning. -# In such cases we are lending the type checkers a hand by specifying the return type -# of the corresponding reflexive method on `other` which will be called instead. -stub_ds = """\ - def {method}(self: T_Dataset, other: DsCompatible) -> T_Dataset: ...{override}""" -stub_da = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... - @overload - def {method}(self, other: "DatasetGroupBy") -> "Dataset": ... - @overload - def {method}(self: T_DataArray, other: DaCompatible) -> T_DataArray: ...""" -stub_var = """\ - @overload{override} - def {method}(self, other: T_Dataset) -> T_Dataset: ... 
-    @overload
-    def {method}(self, other: T_DataArray) -> T_DataArray: ...
-    @overload
-    def {method}(self: T_Variable, other: VarCompatible) -> T_Variable: ..."""
-stub_dsgb = """\
-    @overload{override}
-    def {method}(self, other: T_Dataset) -> T_Dataset: ...
-    @overload
-    def {method}(self, other: "DataArray") -> "Dataset": ...
-    @overload
-    def {method}(self, other: GroupByIncompatible) -> NoReturn: ..."""
-stub_dagb = """\
-    @overload{override}
-    def {method}(self, other: T_Dataset) -> T_Dataset: ...
-    @overload
-    def {method}(self, other: T_DataArray) -> T_DataArray: ...
-    @overload
-    def {method}(self, other: GroupByIncompatible) -> NoReturn: ..."""
-stub_unary = """\
-    def {method}(self: {self_type}) -> {self_type}: ..."""
-stub_other_unary = """\
-    def {method}(self: {self_type}, *args, **kwargs) -> {self_type}: ..."""
-stub_required_unary = """\
-    def _unary_op(self, f, *args, **kwargs): ..."""
-stub_required_binary = """\
-    def _binary_op(self, other, f, reflexive=...): ..."""
-stub_required_inplace = """\
-    def _inplace_binary_op(self, other, f): ..."""
-
-
-def unops(self_type):
-    extra_context = {"self_type": self_type}
+# We need to add "# type: ignore[override]"
+# Keep an eye out for:
+# https://discuss.python.org/t/make-type-hints-for-eq-of-primitives-less-strict/34240
+# The type ignores might not be necessary anymore at some point.
+#
+# We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray.
+# In reality this returns NotImplemented, but this is not a valid type in python 3.9.
+# Therefore, we use NoReturn which mypy seems to recognise!
+# TODO: change once python 3.10 is the minimum.
+#
+# Mypy seems to require that __iadd__ and __add__ have the same signature.
+# This requires some extra type: ignores[misc] in the inplace methods :/
+
+
+def _type_ignore(ignore: str) -> str:
+    return f" # type:ignore[{ignore}]" if ignore else ""
+
+
+FuncType = Sequence[tuple[Optional[str], Optional[str]]]
+OpsType = tuple[FuncType, str, dict[str, str]]
+
+
+def binops(
+    other_type: str, return_type: str = "Self", type_ignore_eq: str = "override"
+) -> list[OpsType]:
+    extras = {"other_type": other_type, "return_type": return_type}
     return [
-        ([(None, None)], required_method_unary, stub_required_unary, {}),
-        (UNARY_OPS, template_unary, stub_unary, extra_context),
-        (OTHER_UNARY_METHODS, template_other_unary, stub_other_unary, extra_context),
+        ([(None, None)], required_method_binary, extras),
+        (BINOPS_NUM + BINOPS_CMP, template_binop, extras | {"type_ignore": ""}),
+        (
+            BINOPS_EQNE,
+            template_binop,
+            extras | {"type_ignore": _type_ignore(type_ignore_eq)},
+        ),
+        (BINOPS_REFLEXIVE, template_reflexive, extras),
     ]
 
 
-def binops(stub=""):
+def binops_overload(
+    other_type: str,
+    overload_type: str,
+    return_type: str = "Self",
+    type_ignore_eq: str = "override",
+) -> list[OpsType]:
+    extras = {"other_type": other_type, "return_type": return_type}
     return [
-        ([(None, None)], required_method_binary, stub_required_binary, {}),
-        (BINOPS_NUM + BINOPS_CMP, template_binop, stub, NO_OVERRIDE),
-        (BINOPS_EQNE, template_binop, stub, OVERRIDE_TYPESHED),
-        (BINOPS_REFLEXIVE, template_reflexive, stub, NO_OVERRIDE),
+        ([(None, None)], required_method_binary, extras),
+        (
+            BINOPS_NUM + BINOPS_CMP,
+            template_binop_overload + template_binop,
+            extras
+            | {
+                "overload_type": overload_type,
+                "type_ignore": "",
+                "overload_type_ignore": "",
+            },
+        ),
+        (
+            BINOPS_EQNE,
+            template_binop_overload + template_binop,
+            extras
+            | {
+                "overload_type": overload_type,
"type_ignore": "", + "overload_type_ignore": _type_ignore(type_ignore_eq), + }, + ), + (BINOPS_REFLEXIVE, template_reflexive, extras), ] -def inplace(): +def inplace(other_type: str, type_ignore: str = "") -> list[OpsType]: + extras = {"other_type": other_type} return [ - ([(None, None)], required_method_inplace, stub_required_inplace, {}), - (BINOPS_INPLACE, template_inplace, "", {}), + ([(None, None)], required_method_inplace, extras), + ( + BINOPS_INPLACE, + template_inplace, + extras | {"type_ignore": _type_ignore(type_ignore)}, + ), + ] + + +def unops() -> list[OpsType]: + return [ + ([(None, None)], required_method_unary, {}), + (UNARY_OPS, template_unary, {}), + (OTHER_UNARY_METHODS, template_other_unary, {}), ] ops_info = {} -ops_info["DatasetOpsMixin"] = binops(stub_ds) + inplace() + unops("T_Dataset") -ops_info["DataArrayOpsMixin"] = binops(stub_da) + inplace() + unops("T_DataArray") -ops_info["VariableOpsMixin"] = binops(stub_var) + inplace() + unops("T_Variable") -ops_info["DatasetGroupByOpsMixin"] = binops(stub_dsgb) -ops_info["DataArrayGroupByOpsMixin"] = binops(stub_dagb) +ops_info["DatasetOpsMixin"] = ( + binops(other_type="DsCompatible") + inplace(other_type="DsCompatible") + unops() +) +ops_info["DataArrayOpsMixin"] = ( + binops(other_type="DaCompatible") + inplace(other_type="DaCompatible") + unops() +) +ops_info["VariableOpsMixin"] = ( + binops_overload(other_type="VarCompatible", overload_type="T_DataArray") + + inplace(other_type="VarCompatible", type_ignore="misc") + + unops() +) +ops_info["DatasetGroupByOpsMixin"] = binops( + other_type="GroupByCompatible", return_type="Dataset" +) +ops_info["DataArrayGroupByOpsMixin"] = binops( + other_type="T_Xarray", return_type="T_Xarray" +) MODULE_PREAMBLE = '''\ """Mixin classes with arithmetic operators.""" # This file was generated using xarray.util.generate_ops. Do not edit manually. -import operator - -from . import nputils, ops''' - -STUBFILE_PREAMBLE = '''\ -"""Stub file for mixin classes with arithmetic operators.""" -# This file was generated using xarray.util.generate_ops. Do not edit manually. - -from typing import NoReturn, TypeVar, overload +from __future__ import annotations -import numpy as np -from numpy.typing import ArrayLike +import operator +from typing import TYPE_CHECKING, Any, Callable, NoReturn, overload -from .dataarray import DataArray -from .dataset import Dataset -from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy -from .types import ( +from xarray.core import nputils, ops +from xarray.core.types import ( DaCompatible, DsCompatible, - GroupByIncompatible, - ScalarOrArray, + GroupByCompatible, + Self, + T_DataArray, + T_Xarray, VarCompatible, ) -from .variable import Variable -try: - from dask.array import Array as DaskArray -except ImportError: - DaskArray = np.ndarray # type: ignore - -# DatasetOpsMixin etc. are parent classes of Dataset etc. -# Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally -# we use the ones in `types`. (We're open to refining this, and potentially integrating -# the `py` & `pyi` files to simplify them.) 
-T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin") -T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin") -T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")''' +if TYPE_CHECKING: + from xarray.core.dataset import Dataset''' CLASS_PREAMBLE = """{newline} @@ -233,35 +258,28 @@ class {cls_name}: {method}.__doc__ = {func}.__doc__""" -def render(ops_info, is_module): +def render(ops_info: dict[str, list[OpsType]]) -> Iterator[str]: """Render the module or stub file.""" - yield MODULE_PREAMBLE if is_module else STUBFILE_PREAMBLE + yield MODULE_PREAMBLE for cls_name, method_blocks in ops_info.items(): - yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n" * is_module) - yield from _render_classbody(method_blocks, is_module) + yield CLASS_PREAMBLE.format(cls_name=cls_name, newline="\n") + yield from _render_classbody(method_blocks) -def _render_classbody(method_blocks, is_module): - for method_func_pairs, method_template, stub_template, extra in method_blocks: - template = method_template if is_module else stub_template +def _render_classbody(method_blocks: list[OpsType]) -> Iterator[str]: + for method_func_pairs, template, extra in method_blocks: if template: for method, func in method_func_pairs: yield template.format(method=method, func=func, **extra) - if is_module: - yield "" - for method_func_pairs, *_ in method_blocks: - for method, func in method_func_pairs: - if method and func: - yield COPY_DOCSTRING.format(method=method, func=func) + yield "" + for method_func_pairs, *_ in method_blocks: + for method, func in method_func_pairs: + if method and func: + yield COPY_DOCSTRING.format(method=method, func=func) if __name__ == "__main__": - option = sys.argv[1].lower() if len(sys.argv) == 2 else None - if option not in {"--module", "--stubs"}: - raise SystemExit(f"Usage: {sys.argv[0]} --module | --stubs") - is_module = option == "--module" - - for line in render(ops_info, is_module): + for line in render(ops_info): print(line) From bac90ab067d7437de875ec57d90d863169e70429 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Mon, 25 Sep 2023 06:46:48 +0200 Subject: [PATCH 15/46] adapt to NEP 51 (#8064) * convert string and bytes items to standard python types * [test-upstream] * modify the expected error message --- xarray/core/formatting.py | 2 ++ xarray/tests/test_dataset.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 3bfe902f0a3..942bf5891ca 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -160,6 +160,8 @@ def format_item(x, timedelta_format=None, quote_strings=True): if isinstance(x, (np.timedelta64, timedelta)): return format_timedelta(x, timedelta_format=timedelta_format) elif isinstance(x, (str, bytes)): + if hasattr(x, "dtype"): + x = x.item() return repr(x) if quote_strings else x elif hasattr(x, "dtype") and np.issubdtype(x.dtype, np.floating): return f"{x.item():.4}" diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c832663ecff..3fb29e01ebb 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4116,7 +4116,8 @@ def test_setitem(self) -> None: data4[{"dim2": [2, 3]}] = data3["var1"][{"dim2": [3, 4]}].values data5 = data4.astype(str) data5["var4"] = data4["var1"] - err_msg = "could not convert string to float: 'a'" + # convert to `np.str_('a')` once `numpy<2.0` has been dropped + err_msg = "could not convert string to float: .*'a'.*" with pytest.raises(ValueError, match=err_msg): 
data5[{"dim2": 1}] = "a" From da647b06312bd93c3412ddd712bf7ecb52e3f28b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Mon, 25 Sep 2023 13:30:33 +0200 Subject: [PATCH 16/46] decode variable with mismatched coordinate attribute (#8195) * decode variable with mismatched coordinate attribute, warn/raise meaningful error * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray/conventions.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * use set comparison as suggested by review * use emit_user_level_warning for all occurrences * fix typing and docstring * fix typing and docstring * silently ignore names of missing variables as per review * only decode if there is at least one variable matching a coordinate * fix typing * update docstring * Apply suggestions from code review Co-authored-by: Deepak Cherian * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add review suggestion Co-authored-by: Deepak Cherian * Fix test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: Deepak Cherian --- doc/whats-new.rst | 2 ++ xarray/backends/api.py | 6 ++++ xarray/conventions.py | 54 ++++++++++++++++---------------- xarray/tests/test_conventions.py | 27 +++++++++++++++- 4 files changed, 61 insertions(+), 28 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4307c2829ca..c37a3213793 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -87,6 +87,8 @@ Bug fixes issues (:issue:`7817`, :issue:`7942`, :issue:`7790`, :issue:`6191`, :issue:`7096`, :issue:`1064`, :pull:`7827`). By `Kai Mühlbauer `_. +- Fixed a bug where inaccurate ``coordinates`` silently failed to decode variable (:issue:`1809`, :pull:`8195`). + By `Kai Mühlbauer `_ - ``.rolling_exp`` functions no longer mistakenly lose non-dimensioned coords (:issue:`6528`, :pull:`8114`) By `Maximilian Roos `_. diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 58a05aeddce..7ca4377e4cf 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -488,6 +488,9 @@ def open_dataset( as coordinate variables. - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and other attributes as coordinate variables. + + Only existing variables can be set as coordinates. Missing variables + will be silently ignored. drop_variables: str or iterable of str, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or @@ -691,6 +694,9 @@ def open_dataarray( as coordinate variables. - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and other attributes as coordinate variables. + + Only existing variables can be set as coordinates. Missing variables + will be silently ignored. drop_variables: str or iterable of str, optional A variable or list of variables to exclude from being parsed from the dataset. 
This may be useful to drop variables with problems or diff --git a/xarray/conventions.py b/xarray/conventions.py index 596831e270a..cf207f0c37a 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -1,9 +1,8 @@ from __future__ import annotations -import warnings from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping, MutableMapping -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Literal, Union import numpy as np import pandas as pd @@ -16,6 +15,7 @@ contains_cftime_datetimes, ) from xarray.core.pycompat import is_duck_dask_array +from xarray.core.utils import emit_user_level_warning from xarray.core.variable import IndexVariable, Variable CF_RELATED_DATA = ( @@ -111,13 +111,13 @@ def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: return var if is_duck_dask_array(data): - warnings.warn( + emit_user_level_warning( f"variable {name} has data in the form of a dask array with " "dtype=object, which means it is being loaded into memory " "to determine a data type that can be safely stored on disk. " "To avoid this, coerce this variable to a fixed-size dtype " "with astype() before saving it.", - SerializationWarning, + category=SerializationWarning, ) data = data.compute() @@ -354,15 +354,14 @@ def _update_bounds_encoding(variables: T_Variables) -> None: and "bounds" in attrs and attrs["bounds"] in variables ): - warnings.warn( - "Variable '{0}' has datetime type and a " - "bounds variable but {0}.encoding does not have " - "units specified. The units encodings for '{0}' " - "and '{1}' will be determined independently " + emit_user_level_warning( + f"Variable {name:s} has datetime type and a " + f"bounds variable but {name:s}.encoding does not have " + f"units specified. The units encodings for {name:s} " + f"and {attrs['bounds']} will be determined independently " "and may not be equal, counter to CF-conventions. 
" "If this is a concern, specify a units encoding for " - "'{0}' before writing to a file.".format(name, attrs["bounds"]), - UserWarning, + f"{name:s} before writing to a file.", ) if has_date_units and "bounds" in attrs: @@ -379,7 +378,7 @@ def decode_cf_variables( concat_characters: bool = True, mask_and_scale: bool = True, decode_times: bool = True, - decode_coords: bool = True, + decode_coords: bool | Literal["coordinates", "all"] = True, drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, @@ -441,11 +440,14 @@ def stackable(dim: Hashable) -> bool: if decode_coords in [True, "coordinates", "all"]: var_attrs = new_vars[k].attrs if "coordinates" in var_attrs: - coord_str = var_attrs["coordinates"] - var_coord_names = coord_str.split() - if all(k in variables for k in var_coord_names): - new_vars[k].encoding["coordinates"] = coord_str - del var_attrs["coordinates"] + var_coord_names = [ + c for c in var_attrs["coordinates"].split() if c in variables + ] + # propagate as is + new_vars[k].encoding["coordinates"] = var_attrs["coordinates"] + del var_attrs["coordinates"] + # but only use as coordinate if existing + if var_coord_names: coord_names.update(var_coord_names) if decode_coords == "all": @@ -461,8 +463,8 @@ def stackable(dim: Hashable) -> bool: for role_or_name in part.split() ] if len(roles_and_names) % 2 == 1: - warnings.warn( - f"Attribute {attr_name:s} malformed", stacklevel=5 + emit_user_level_warning( + f"Attribute {attr_name:s} malformed" ) var_names = roles_and_names[1::2] if all(var_name in variables for var_name in var_names): @@ -474,9 +476,8 @@ def stackable(dim: Hashable) -> bool: for proj_name in var_names if proj_name not in variables ] - warnings.warn( + emit_user_level_warning( f"Variable(s) referenced in {attr_name:s} not in variables: {referenced_vars_not_in_variables!s}", - stacklevel=5, ) del var_attrs[attr_name] @@ -493,7 +494,7 @@ def decode_cf( concat_characters: bool = True, mask_and_scale: bool = True, decode_times: bool = True, - decode_coords: bool = True, + decode_coords: bool | Literal["coordinates", "all"] = True, drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, @@ -632,12 +633,11 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): for name in list(non_dim_coord_names): if isinstance(name, str) and " " in name: - warnings.warn( + emit_user_level_warning( f"coordinate {name!r} has a space in its name, which means it " "cannot be marked as a coordinate on disk and will be " "saved as a data variable instead", - SerializationWarning, - stacklevel=6, + category=SerializationWarning, ) non_dim_coord_names.discard(name) @@ -710,11 +710,11 @@ def _encode_coordinates(variables, attributes, non_dim_coord_names): if global_coordinates: attributes = dict(attributes) if "coordinates" in attributes: - warnings.warn( + emit_user_level_warning( f"cannot serialize global coordinates {global_coordinates!r} because the global " f"attribute 'coordinates' already exists. 
This may prevent faithful roundtripping" f"of xarray datasets", - SerializationWarning, + category=SerializationWarning, ) else: attributes["coordinates"] = " ".join(sorted(map(str, global_coordinates))) diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index 4dae7809be9..5157688b629 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -80,6 +80,28 @@ def test_decode_cf_with_conflicting_fill_missing_value() -> None: assert_identical(actual, expected) +def test_decode_cf_variable_with_mismatched_coordinates() -> None: + # tests for decoding mismatched coordinates attributes + # see GH #1809 + zeros1 = np.zeros((1, 5, 3)) + orig = Dataset( + { + "XLONG": (["x", "y"], zeros1.squeeze(0), {}), + "XLAT": (["x", "y"], zeros1.squeeze(0), {}), + "foo": (["time", "x", "y"], zeros1, {"coordinates": "XTIME XLONG XLAT"}), + "time": ("time", [0.0], {"units": "hours since 2017-01-01"}), + } + ) + decoded = conventions.decode_cf(orig, decode_coords=True) + assert decoded["foo"].encoding["coordinates"] == "XTIME XLONG XLAT" + assert list(decoded.coords.keys()) == ["XLONG", "XLAT", "time"] + + decoded = conventions.decode_cf(orig, decode_coords=False) + assert "coordinates" not in decoded["foo"].encoding + assert decoded["foo"].attrs.get("coordinates") == "XTIME XLONG XLAT" + assert list(decoded.coords.keys()) == ["time"] + + @requires_cftime class TestEncodeCFVariable: def test_incompatible_attributes(self) -> None: @@ -246,9 +268,12 @@ def test_dataset(self) -> None: assert_identical(expected, actual) def test_invalid_coordinates(self) -> None: - # regression test for GH308 + # regression test for GH308, GH1809 original = Dataset({"foo": ("t", [1, 2], {"coordinates": "invalid"})}) + decoded = Dataset({"foo": ("t", [1, 2], {}, {"coordinates": "invalid"})}) actual = conventions.decode_cf(original) + assert_identical(decoded, actual) + actual = conventions.decode_cf(original, decode_coords=False) assert_identical(original, actual) def test_decode_coordinates(self) -> None: From ba7f2d5dc8a6f77c91a3bb9e54b5ca39abcdb939 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Tue, 26 Sep 2023 10:12:44 +0200 Subject: [PATCH 17/46] Release 2023.09.0 (#8229) --- doc/whats-new.rst | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c37a3213793..2b0d2e151c5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,10 +14,21 @@ What's New np.random.seed(123456) -.. _whats-new.2023.08.1: +.. _whats-new.2023.09.0: -v2023.08.1 (unreleased) ------------------------ +v2023.09.0 (Sep 26, 2023) +------------------------- + +This release continues work on the new :py:class:`xarray.Coordinates` object, allows to provide `preferred_chunks` when +reading from netcdf files, enables :py:func:`xarray.apply_ufunc` to handle missing core dimensions and fixes several bugs. + +Thanks to the 24 contributors to this release: Alexander Fischer, Amrest Chinkamol, Benoit Bovy, Darsh Ranjan, Deepak Cherian, +Gianfranco Costamagna, Gregorio L. Trevisan, Illviljan, Joe Hamman, JR, Justus Magin, Kai Mühlbauer, Kian-Meng Ang, Kyle Sunden, +Martin Raspaud, Mathias Hauser, Mattia Almansi, Maximilian Roos, András Gunyhó, Michael Niklas, Richard Kleijn, Riulinchen, +Tom Nicholas and Wiktor Kraśnicki. + +We welcome the following new contributors to Xarray!: Alexander Fischer, Amrest Chinkamol, Darsh Ranjan, Gianfranco Costamagna, Gregorio L. 
Trevisan, +Kian-Meng Ang, Riulinchen and Wiktor Kraśnicki. New Features ~~~~~~~~~~~~ @@ -28,11 +39,8 @@ New Features By `Benoît Bovy `_. - Provide `preferred_chunks` for data read from netcdf files (:issue:`1440`, :pull:`7948`). By `Martin Raspaud `_. -- Improved static typing of reduction methods (:pull:`6746`). - By `Richard Kleijn `_. - Added `on_missing_core_dims` to :py:meth:`apply_ufunc` to allow for copying or - dropping a :py:class:`Dataset`'s variables with missing core dimensions. - (:pull:`8138`) + dropping a :py:class:`Dataset`'s variables with missing core dimensions (:pull:`8138`). By `Maximilian Roos `_. Breaking changes @@ -61,6 +69,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Improved static typing of reduction methods (:pull:`6746`). + By `Richard Kleijn `_. - Fix bug where empty attrs would generate inconsistent tokens (:issue:`6970`, :pull:`8101`). By `Mattia Almansi `_. - Improved handling of multi-coordinate indexes when updating coordinates, including bug fixes @@ -71,8 +81,8 @@ Bug fixes :pull:`8104`). By `Benoît Bovy `_. - Fix bug where :py:class:`DataArray` instances on the right-hand side - of :py:meth:`DataArray.__setitem__` lose dimension names. - (:issue:`7030`, :pull:`8067`) By `Darsh Ranjan `_. + of :py:meth:`DataArray.__setitem__` lose dimension names (:issue:`7030`, :pull:`8067`). + By `Darsh Ranjan `_. - Return ``float64`` in presence of ``NaT`` in :py:class:`~core.accessor_dt.DatetimeAccessor` and special case ``NaT`` handling in :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar` (:issue:`7928`, :pull:`8084`). @@ -83,14 +93,13 @@ Bug fixes - Calling plot with kwargs ``col``, ``row`` or ``hue`` no longer squeezes dimensions passed via these arguments (:issue:`7552`, :pull:`8174`). By `Wiktor Kraśnicki `_. -- Fixed a bug where casting from ``float`` to ``int64`` (undefined for ``NaN``) led to varying - issues (:issue:`7817`, :issue:`7942`, :issue:`7790`, :issue:`6191`, :issue:`7096`, +- Fixed a bug where casting from ``float`` to ``int64`` (undefined for ``NaN``) led to varying issues (:issue:`7817`, :issue:`7942`, :issue:`7790`, :issue:`6191`, :issue:`7096`, :issue:`1064`, :pull:`7827`). By `Kai Mühlbauer `_. - Fixed a bug where inaccurate ``coordinates`` silently failed to decode variable (:issue:`1809`, :pull:`8195`). By `Kai Mühlbauer `_ - ``.rolling_exp`` functions no longer mistakenly lose non-dimensioned coords - (:issue:`6528`, :pull:`8114`) + (:issue:`6528`, :pull:`8114`). By `Maximilian Roos `_. - In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`). By `Kai Mühlbauer `_. @@ -101,6 +110,8 @@ Bug fixes Documentation ~~~~~~~~~~~~~ +- Make documentation of :py:meth:`DataArray.where` clearer (:issue:`7767`, :pull:`7955`). + By `Riulinchen `_. Internal Changes ~~~~~~~~~~~~~~~~ @@ -116,7 +127,6 @@ Internal Changes - Test range of fill_value's in test_interpolate_pd_compat (:issue:`8146`, :pull:`8189`). By `Kai Mühlbauer `_. - .. 
_whats-new.2023.08.0: v2023.08.0 (Aug 18, 2023) From 84f5a0d2eef69cd2ee127c138e61066637d9f6ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Tue, 26 Sep 2023 17:24:49 +0200 Subject: [PATCH 18/46] [skip-ci] dev whats-new (#8232) --- doc/whats-new.rst | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 2b0d2e151c5..17744288aef 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -14,6 +14,35 @@ What's New np.random.seed(123456) +.. _whats-new.2023.09.1: + +v2023.09.1 (unreleased) +----------------------- + +New Features +~~~~~~~~~~~~ + + +Breaking changes +~~~~~~~~~~~~~~~~ + + +Deprecations +~~~~~~~~~~~~ + + +Bug fixes +~~~~~~~~~ + + +Documentation +~~~~~~~~~~~~~ + + +Internal Changes +~~~~~~~~~~~~~~~~ + + .. _whats-new.2023.09.0: v2023.09.0 (Sep 26, 2023) From c3b5ead8cdbce38157265fd449a2a641cc118066 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> Date: Wed, 27 Sep 2023 08:23:20 -0700 Subject: [PATCH 19/46] initial refactor for NamedArray (#8075) * initial prototype for NamedArray * move NDArrayMixin and NdimSizeLenMixin inside named_array * vendor is_duck_dask_array * vendor Frozen object * update import * move _default sentinel value * rename subpackage to namedarray per @TomNicholas suggestion * Remove NdimSizeLenMixin * fix typing * add annotations * Remove NDArrayMixin * Apply suggestions from code review Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * fix typing * fix return type * revert NDArrayMixin * [WIP] as_compatible_data refactor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * duplicate sentinel value and leave the original sentinel object alone * Apply suggestions from code review Co-authored-by: Stephan Hoyer * use DuckArray * Apply suggestions from code review Co-authored-by: Stephan Hoyer * use sentinel value from xarray * remove unused code * fix variable constructor * fix as_compatible_data utility function * move _to_dense and _non_zero to NamedArray * more typing * add initial tests * Apply suggestions from code review Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * attempt to fix some mypy errors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * All input data can be arraylike * Update core.py * Update core.py * get and set attrs at the same level. * data doesn't have to be ndarray * avoid redefining typing use new variable names instead * import on runtime as well to be able to cast * requires ufunc and function to be a valid duck array * Add array_namespace * Update test_dataset.py * Update test_dataset.py * remove Frozen * update tests * update tests * switch to functional API * add fastpath * Test making sizes dict[Hashable, int] * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * A lot of errors... 
Try Mapping instead * Update groupby.py * Update types.py * Apply suggestions from code review Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: Deepak Cherian * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update docstrings * update error messages * update tests * test explicitly index array * update tests * remove unused types * Update xarray/tests/test_namedarray.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use Self --------- Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: dcherian Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Stephan Hoyer Co-authored-by: Deepak Cherian --- xarray/core/common.py | 2 +- xarray/core/groupby.py | 4 +- xarray/core/types.py | 2 +- xarray/core/variable.py | 297 +-------------------- xarray/namedarray/__init__.py | 0 xarray/namedarray/core.py | 447 ++++++++++++++++++++++++++++++++ xarray/namedarray/utils.py | 68 +++++ xarray/tests/test_dataset.py | 6 +- xarray/tests/test_formatting.py | 2 +- xarray/tests/test_namedarray.py | 165 ++++++++++++ 10 files changed, 700 insertions(+), 293 deletions(-) create mode 100644 xarray/namedarray/__init__.py create mode 100644 xarray/namedarray/core.py create mode 100644 xarray/namedarray/utils.py create mode 100644 xarray/tests/test_namedarray.py diff --git a/xarray/core/common.py b/xarray/core/common.py index e4e3e60e815..db9b2aead23 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -223,7 +223,7 @@ def _get_axis_num(self: Any, dim: Hashable) -> int: raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") @property - def sizes(self: Any) -> Frozen[Hashable, int]: + def sizes(self: Any) -> Mapping[Hashable, int]: """Ordered mapping from dimension names to lengths. Immutable. diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 9894a4a4daf..e9ddf044568 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -699,7 +699,7 @@ class GroupBy(Generic[T_Xarray]): _groups: dict[GroupKey, GroupIndex] | None _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None - _sizes: Frozen[Hashable, int] | None + _sizes: Mapping[Hashable, int] | None def __init__( self, @@ -746,7 +746,7 @@ def __init__( self._sizes = None @property - def sizes(self) -> Frozen[Hashable, int]: + def sizes(self) -> Mapping[Hashable, int]: """Ordered mapping from dimension names to lengths. Immutable. diff --git a/xarray/core/types.py b/xarray/core/types.py index 073121b13b1..bbcda7ca240 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -106,7 +106,7 @@ def dims(self) -> Frozen[Hashable, int] | tuple[Hashable, ...]: ... @property - def sizes(self) -> Frozen[Hashable, int]: + def sizes(self) -> Mapping[Hashable, int]: ... 
@property diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 2571b093450..0e6e45d4929 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -26,10 +26,7 @@ as_indexable, ) from xarray.core.options import OPTIONS, _get_keep_attrs -from xarray.core.parallelcompat import ( - get_chunked_array_type, - guess_chunkmanager, -) +from xarray.core.parallelcompat import get_chunked_array_type, guess_chunkmanager from xarray.core.pycompat import ( array_type, integer_types, @@ -38,8 +35,6 @@ is_duck_dask_array, ) from xarray.core.utils import ( - Frozen, - NdimSizeLenMixin, OrderedSet, _default, decode_numpy_dict_values, @@ -50,6 +45,7 @@ is_duck_array, maybe_coerce_to_str, ) +from xarray.namedarray.core import NamedArray NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( indexing.ExplicitlyIndexed, @@ -268,7 +264,7 @@ def as_compatible_data( data = np.timedelta64(getattr(data, "value", data), "ns") # we don't want nested self-described arrays - if isinstance(data, (pd.Series, pd.Index, pd.DataFrame)): + if isinstance(data, (pd.Series, pd.DataFrame)): data = data.values if isinstance(data, np.ma.MaskedArray): @@ -315,7 +311,7 @@ def _as_array_or_item(data): return data -class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic): +class Variable(NamedArray, AbstractArray, VariableArithmetic): """A netcdf-like variable consisting of dimensions, data and attributes which describe a single Array. A single Variable object is not fully described outside the context of its parent Dataset (if you want such a @@ -365,51 +361,14 @@ def __init__( Well-behaved code to serialize a Variable should ignore unrecognized encoding items. """ - self._data: T_DuckArray = as_compatible_data(data, fastpath=fastpath) - self._dims = self._parse_dimensions(dims) - self._attrs: dict[Any, Any] | None = None + super().__init__( + dims=dims, data=as_compatible_data(data, fastpath=fastpath), attrs=attrs + ) + self._encoding = None - if attrs is not None: - self.attrs = attrs if encoding is not None: self.encoding = encoding - @property - def dtype(self) -> np.dtype: - """ - Data-type of the array’s elements. - - See Also - -------- - ndarray.dtype - numpy.dtype - """ - return self._data.dtype - - @property - def shape(self) -> tuple[int, ...]: - """ - Tuple of array dimensions. - - See Also - -------- - numpy.ndarray.shape - """ - return self._data.shape - - @property - def nbytes(self) -> int: - """ - Total bytes consumed by the elements of the data array. - - If the underlying data array does not include ``nbytes``, estimates - the bytes consumed based on the ``size`` and ``dtype``. - """ - if hasattr(self._data, "nbytes"): - return self._data.nbytes - else: - return self.size * self.dtype.itemsize - @property def _in_memory(self): return isinstance( @@ -441,11 +400,7 @@ def data(self): @data.setter def data(self, data: T_DuckArray | ArrayLike) -> None: data = as_compatible_data(data) - if data.shape != self.shape: # type: ignore[attr-defined] - raise ValueError( - f"replacement data must match the Variable's shape. 
" - f"replacement data has shape {data.shape}; Variable has shape {self.shape}" # type: ignore[attr-defined] - ) + self._check_shape(data) self._data = data def astype( @@ -571,41 +526,6 @@ def compute(self, **kwargs): new = self.copy(deep=False) return new.load(**kwargs) - def __dask_tokenize__(self): - # Use v.data, instead of v._data, in order to cope with the wrappers - # around NetCDF and the like - from dask.base import normalize_token - - return normalize_token((type(self), self._dims, self.data, self.attrs)) - - def __dask_graph__(self): - if is_duck_dask_array(self._data): - return self._data.__dask_graph__() - else: - return None - - def __dask_keys__(self): - return self._data.__dask_keys__() - - def __dask_layers__(self): - return self._data.__dask_layers__() - - @property - def __dask_optimize__(self): - return self._data.__dask_optimize__ - - @property - def __dask_scheduler__(self): - return self._data.__dask_scheduler__ - - def __dask_postcompute__(self): - array_func, array_args = self._data.__dask_postcompute__() - return self._dask_finalize, (array_func,) + array_args - - def __dask_postpersist__(self): - array_func, array_args = self._data.__dask_postpersist__() - return self._dask_finalize, (array_func,) + array_args - def _dask_finalize(self, results, array_func, *args, **kwargs): data = array_func(results, *args, **kwargs) return Variable(self._dims, data, attrs=self._attrs, encoding=self._encoding) @@ -667,27 +587,6 @@ def to_dict( return item - @property - def dims(self) -> tuple[Hashable, ...]: - """Tuple of dimension names with which this variable is associated.""" - return self._dims - - @dims.setter - def dims(self, value: str | Iterable[Hashable]) -> None: - self._dims = self._parse_dimensions(value) - - def _parse_dimensions(self, dims: str | Iterable[Hashable]) -> tuple[Hashable, ...]: - if isinstance(dims, str): - dims = (dims,) - else: - dims = tuple(dims) - if len(dims) != self.ndim: - raise ValueError( - f"dimensions {dims} must have the same length as the " - f"number of data dimensions, ndim={self.ndim}" - ) - return dims - def _item_key_to_tuple(self, key): if utils.is_dict_like(key): return tuple(key.get(dim, slice(None)) for dim in self.dims) @@ -820,13 +719,6 @@ def _broadcast_indexes_outer(self, key): return dims, OuterIndexer(tuple(new_key)), None - def _nonzero(self): - """Equivalent numpy's nonzero but returns a tuple of Variables.""" - # TODO we should replace dask's native nonzero - # after https://github.com/dask/dask/issues/1076 is implemented. - nonzeros = np.nonzero(self.data) - return tuple(Variable((dim), nz) for nz, dim in zip(nonzeros, self.dims)) - def _broadcast_indexes_vectorized(self, key): variables = [] out_dims_set = OrderedSet() @@ -976,17 +868,6 @@ def __setitem__(self, key, value): indexable = as_indexable(self._data) indexable[index_tuple] = value - @property - def attrs(self) -> dict[Any, Any]: - """Dictionary of local attributes on this variable.""" - if self._attrs is None: - self._attrs = {} - return self._attrs - - @attrs.setter - def attrs(self, value: Mapping[Any, Any]) -> None: - self._attrs = dict(value) - @property def encoding(self) -> dict[Any, Any]: """Dictionary of encodings on this variable.""" @@ -1005,66 +886,6 @@ def reset_encoding(self) -> Self: """Return a new Variable without encoding.""" return self._replace(encoding={}) - def copy( - self, deep: bool = True, data: T_DuckArray | ArrayLike | None = None - ) -> Self: - """Returns a copy of this object. 
- - If `deep=True`, the data array is loaded into memory and copied onto - the new object. Dimensions, attributes and encodings are always copied. - - Use `data` to create a new object with the same structure as - original but entirely new data. - - Parameters - ---------- - deep : bool, default: True - Whether the data array is loaded into memory and copied onto - the new object. Default is True. - data : array_like, optional - Data to use in the new object. Must have same shape as original. - When `data` is used, `deep` is ignored. - - Returns - ------- - object : Variable - New object with dimensions, attributes, encodings, and optionally - data copied from original. - - Examples - -------- - Shallow copy versus deep copy - - >>> var = xr.Variable(data=[1, 2, 3], dims="x") - >>> var.copy() - - array([1, 2, 3]) - >>> var_0 = var.copy(deep=False) - >>> var_0[0] = 7 - >>> var_0 - - array([7, 2, 3]) - >>> var - - array([7, 2, 3]) - - Changing the data using the ``data`` argument maintains the - structure of the original object, but with the new data. Original - object is unaffected. - - >>> var.copy(data=[0.1, 0.2, 0.3]) - - array([0.1, 0.2, 0.3]) - >>> var - - array([7, 2, 3]) - - See Also - -------- - pandas.DataFrame.copy - """ - return self._copy(deep=deep, data=data) - def _copy( self, deep: bool = True, @@ -1111,57 +932,11 @@ def _replace( data = copy.copy(self.data) if attrs is _default: attrs = copy.copy(self._attrs) + if encoding is _default: encoding = copy.copy(self._encoding) return type(self)(dims, data, attrs, encoding, fastpath=True) - def __copy__(self) -> Self: - return self._copy(deep=False) - - def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: - return self._copy(deep=True, memo=memo) - - # mutable objects should not be hashable - # https://github.com/python/mypy/issues/4266 - __hash__ = None # type: ignore[assignment] - - @property - def chunks(self) -> tuple[tuple[int, ...], ...] | None: - """ - Tuple of block lengths for this dataarray's data, in order of dimensions, or None if - the underlying data is not a dask array. - - See Also - -------- - Variable.chunk - Variable.chunksizes - xarray.unify_chunks - """ - return getattr(self._data, "chunks", None) - - @property - def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: - """ - Mapping from dimension names to block lengths for this variable's data, or None if - the underlying data is not a dask array. - Cannot be modified directly, but can be modified by calling .chunk(). - - Differs from variable.chunks because it returns a mapping of dimensions to chunk shapes - instead of a tuple of chunk shapes. - - See Also - -------- - Variable.chunk - Variable.chunks - xarray.unify_chunks - """ - if hasattr(self._data, "chunks"): - return Frozen({dim: c for dim, c in zip(self.dims, self.data.chunks)}) - else: - return {} - - _array_counter = itertools.count() - def chunk( self, chunks: ( @@ -1312,36 +1087,6 @@ def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) - def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA): - """ - use sparse-array as backend. - """ - import sparse - - # TODO: what to do if dask-backended? 
- if fill_value is dtypes.NA: - dtype, fill_value = dtypes.maybe_promote(self.dtype) - else: - dtype = dtypes.result_type(self.dtype, fill_value) - - if sparse_format is _default: - sparse_format = "coo" - try: - as_sparse = getattr(sparse, f"as_{sparse_format.lower()}") - except AttributeError: - raise ValueError(f"{sparse_format} is not a valid sparse format") - - data = as_sparse(self.data.astype(dtype), fill_value=fill_value) - return self._replace(data=data) - - def _to_dense(self): - """ - Change backend from sparse to np.array - """ - if hasattr(self._data, "todense"): - return self._replace(data=self._data.todense()) - return self.copy(deep=False) - def isel( self, indexers: Mapping[Any, Any] | None = None, @@ -2649,28 +2394,6 @@ def notnull(self, keep_attrs: bool | None = None): keep_attrs=keep_attrs, ) - @property - def real(self): - """ - The real part of the variable. - - See Also - -------- - numpy.ndarray.real - """ - return self._replace(data=self.data.real) - - @property - def imag(self): - """ - The imaginary part of the variable. - - See Also - -------- - numpy.ndarray.imag - """ - return self._replace(data=self.data.imag) - def __array_wrap__(self, obj, context=None): return Variable(self.dims, obj) diff --git a/xarray/namedarray/__init__.py b/xarray/namedarray/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py new file mode 100644 index 00000000000..16a7b422f1b --- /dev/null +++ b/xarray/namedarray/core.py @@ -0,0 +1,447 @@ +from __future__ import annotations + +import copy +import math +import sys +import typing +from collections.abc import Hashable, Iterable, Mapping + +import numpy as np + +# TODO: get rid of this after migrating this class to array API +from xarray.core import dtypes +from xarray.core.indexing import ExplicitlyIndexed +from xarray.core.utils import Default, _default +from xarray.namedarray.utils import ( + T_DuckArray, + is_duck_array, + is_duck_dask_array, + to_0d_object_array, +) + +if typing.TYPE_CHECKING: + T_NamedArray = typing.TypeVar("T_NamedArray", bound="NamedArray") + DimsInput = typing.Union[str, Iterable[Hashable]] + Dims = tuple[Hashable, ...] + + +try: + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self +except ImportError: + if typing.TYPE_CHECKING: + raise + else: + Self: typing.Any = None + + +# TODO: Add tests! +def as_compatible_data( + data: T_DuckArray | np.typing.ArrayLike, fastpath: bool = False +) -> T_DuckArray: + if fastpath and getattr(data, "ndim", 0) > 0: + # can't use fastpath (yet) for scalars + return typing.cast(T_DuckArray, data) + + if isinstance(data, np.ma.MaskedArray): + mask = np.ma.getmaskarray(data) + if mask.any(): + # TODO: requires refactoring/vendoring xarray.core.dtypes and xarray.core.duck_array_ops + raise NotImplementedError("MaskedArray is not supported yet") + else: + return typing.cast(T_DuckArray, np.asarray(data)) + if is_duck_array(data): + return data + if isinstance(data, NamedArray): + return typing.cast(T_DuckArray, data.data) + + if isinstance(data, ExplicitlyIndexed): + # TODO: better that is_duck_array(ExplicitlyIndexed) -> True + return typing.cast(T_DuckArray, data) + + if isinstance(data, tuple): + data = to_0d_object_array(data) + + # validate whether the data is valid data types. 
+ return typing.cast(T_DuckArray, np.asarray(data)) + + +class NamedArray: + + """A lightweight wrapper around duck arrays with named dimensions and attributes which describe a single Array. + Numeric operations on this object implement array broadcasting and dimension alignment based on dimension names, + rather than axis order.""" + + __slots__ = ("_dims", "_data", "_attrs") + + def __init__( + self, + dims: DimsInput, + data: T_DuckArray | np.typing.ArrayLike, + attrs: dict | None = None, + fastpath: bool = False, + ): + """ + Parameters + ---------- + dims : str or iterable of str + Name(s) of the dimension(s). + data : T_DuckArray or np.typing.ArrayLike + The actual data that populates the array. Should match the shape specified by `dims`. + attrs : dict, optional + A dictionary containing any additional information or attributes you want to store with the array. + Default is None, meaning no attributes will be stored. + fastpath : bool, optional + A flag to indicate if certain validations should be skipped for performance reasons. + Should only be True if you are certain about the integrity of the input data. + Default is False. + + Raises + ------ + ValueError + If the `dims` length does not match the number of data dimensions (ndim). + + + """ + self._data: T_DuckArray = as_compatible_data(data, fastpath=fastpath) + self._dims: Dims = self._parse_dimensions(dims) + self._attrs: dict | None = dict(attrs) if attrs else None + + @property + def ndim(self) -> int: + """ + Number of array dimensions. + + See Also + -------- + numpy.ndarray.ndim + """ + return len(self.shape) + + @property + def size(self) -> int: + """ + Number of elements in the array. + + Equal to ``np.prod(a.shape)``, i.e., the product of the array’s dimensions. + + See Also + -------- + numpy.ndarray.size + """ + return math.prod(self.shape) + + def __len__(self) -> int: + try: + return self.shape[0] + except Exception as exc: + raise TypeError("len() of unsized object") from exc + + @property + def dtype(self) -> np.dtype: + """ + Data-type of the array’s elements. + + See Also + -------- + ndarray.dtype + numpy.dtype + """ + return self._data.dtype + + @property + def shape(self) -> tuple[int, ...]: + """ + + + Returns + ------- + shape : tuple of ints + Tuple of array dimensions. + + + + See Also + -------- + numpy.ndarray.shape + """ + return self._data.shape + + @property + def nbytes(self) -> int: + """ + Total bytes consumed by the elements of the data array. + + If the underlying data array does not include ``nbytes``, estimates + the bytes consumed based on the ``size`` and ``dtype``. 
+ """ + if hasattr(self._data, "nbytes"): + return self._data.nbytes + else: + return self.size * self.dtype.itemsize + + @property + def dims(self) -> Dims: + """Tuple of dimension names with which this NamedArray is associated.""" + return self._dims + + @dims.setter + def dims(self, value: DimsInput) -> None: + self._dims = self._parse_dimensions(value) + + def _parse_dimensions(self, dims: DimsInput) -> Dims: + dims = (dims,) if isinstance(dims, str) else tuple(dims) + if len(dims) != self.ndim: + raise ValueError( + f"dimensions {dims} must have the same length as the " + f"number of data dimensions, ndim={self.ndim}" + ) + return dims + + @property + def attrs(self) -> dict[typing.Any, typing.Any]: + """Dictionary of local attributes on this NamedArray.""" + if self._attrs is None: + self._attrs = {} + return self._attrs + + @attrs.setter + def attrs(self, value: Mapping) -> None: + self._attrs = dict(value) + + def _check_shape(self, new_data: T_DuckArray) -> None: + if new_data.shape != self.shape: + raise ValueError( + f"replacement data must match the {self.__class__.__name__}'s shape. " + f"replacement data has shape {new_data.shape}; {self.__class__.__name__} has shape {self.shape}" + ) + + @property + def data(self): + """ + The NamedArray's data as an array. The underlying array type + (e.g. dask, sparse, pint) is preserved. + + """ + + return self._data + + @data.setter + def data(self, data: T_DuckArray | np.typing.ArrayLike) -> None: + data = as_compatible_data(data) + self._check_shape(data) + self._data = data + + @property + def real(self) -> Self: + """ + The real part of the NamedArray. + + See Also + -------- + numpy.ndarray.real + """ + return self._replace(data=self.data.real) + + @property + def imag(self) -> Self: + """ + The imaginary part of the NamedArray. + + See Also + -------- + numpy.ndarray.imag + """ + return self._replace(data=self.data.imag) + + def __dask_tokenize__(self): + # Use v.data, instead of v._data, in order to cope with the wrappers + # around NetCDF and the like + from dask.base import normalize_token + + return normalize_token((type(self), self._dims, self.data, self.attrs)) + + def __dask_graph__(self): + return self._data.__dask_graph__() if is_duck_dask_array(self._data) else None + + def __dask_keys__(self): + return self._data.__dask_keys__() + + def __dask_layers__(self): + return self._data.__dask_layers__() + + @property + def __dask_optimize__(self) -> typing.Callable: + return self._data.__dask_optimize__ + + @property + def __dask_scheduler__(self) -> typing.Callable: + return self._data.__dask_scheduler__ + + def __dask_postcompute__( + self, + ) -> tuple[typing.Callable, tuple[typing.Any, ...]]: + array_func, array_args = self._data.__dask_postcompute__() + return self._dask_finalize, (array_func,) + array_args + + def __dask_postpersist__( + self, + ) -> tuple[typing.Callable, tuple[typing.Any, ...]]: + array_func, array_args = self._data.__dask_postpersist__() + return self._dask_finalize, (array_func,) + array_args + + def _dask_finalize(self, results, array_func, *args, **kwargs) -> Self: + data = array_func(results, *args, **kwargs) + return type(self)(self._dims, data, attrs=self._attrs) + + @property + def chunks(self) -> tuple[tuple[int, ...], ...] | None: + """ + Tuple of block lengths for this NamedArray's data, in order of dimensions, or None if + the underlying data is not a dask array. 
+ + See Also + -------- + NamedArray.chunk + NamedArray.chunksizes + xarray.unify_chunks + """ + return getattr(self._data, "chunks", None) + + @property + def chunksizes( + self, + ) -> typing.Mapping[typing.Any, tuple[int, ...]]: + """ + Mapping from dimension names to block lengths for this namedArray's data, or None if + the underlying data is not a dask array. + Cannot be modified directly, but can be modified by calling .chunk(). + + Differs from NamedArray.chunks because it returns a mapping of dimensions to chunk shapes + instead of a tuple of chunk shapes. + + See Also + -------- + NamedArray.chunk + NamedArray.chunks + xarray.unify_chunks + """ + if hasattr(self._data, "chunks"): + return dict(zip(self.dims, self.data.chunks)) + else: + return {} + + @property + def sizes(self) -> dict[Hashable, int]: + """Ordered mapping from dimension names to lengths.""" + return dict(zip(self.dims, self.shape)) + + def _replace(self, dims=_default, data=_default, attrs=_default) -> Self: + if dims is _default: + dims = copy.copy(self._dims) + if data is _default: + data = copy.copy(self._data) + if attrs is _default: + attrs = copy.copy(self._attrs) + return type(self)(dims, data, attrs) + + def _copy( + self, + deep: bool = True, + data: T_DuckArray | np.typing.ArrayLike | None = None, + memo: dict[int, typing.Any] | None = None, + ) -> Self: + if data is None: + ndata = self._data + if deep: + ndata = copy.deepcopy(ndata, memo=memo) + else: + ndata = as_compatible_data(data) + self._check_shape(ndata) + + attrs = ( + copy.deepcopy(self._attrs, memo=memo) if deep else copy.copy(self._attrs) + ) + + return self._replace(data=ndata, attrs=attrs) + + def __copy__(self) -> Self: + return self._copy(deep=False) + + def __deepcopy__(self, memo: dict[int, typing.Any] | None = None) -> Self: + return self._copy(deep=True, memo=memo) + + def copy( + self, + deep: bool = True, + data: T_DuckArray | np.typing.ArrayLike | None = None, + ) -> Self: + """Returns a copy of this object. + + If `deep=True`, the data array is loaded into memory and copied onto + the new object. Dimensions, attributes and encodings are always copied. + + Use `data` to create a new object with the same structure as + original but entirely new data. + + Parameters + ---------- + deep : bool, default: True + Whether the data array is loaded into memory and copied onto + the new object. Default is True. + data : array_like, optional + Data to use in the new object. Must have same shape as original. + When `data` is used, `deep` is ignored. + + Returns + ------- + object : NamedArray + New object with dimensions, attributes, and optionally + data copied from original. + + + """ + return self._copy(deep=deep, data=data) + + def _nonzero(self) -> tuple[Self, ...]: + """Equivalent numpy's nonzero but returns a tuple of NamedArrays.""" + # TODO we should replace dask's native nonzero + # after https://github.com/dask/dask/issues/1076 is implemented. + nonzeros = np.nonzero(self.data) + return tuple(type(self)((dim,), nz) for nz, dim in zip(nonzeros, self.dims)) + + def _as_sparse( + self, + sparse_format: str | Default = _default, + fill_value=dtypes.NA, + ) -> Self: + """ + use sparse-array as backend. + """ + import sparse + + # TODO: what to do if dask-backended? 
+ if fill_value is dtypes.NA: + dtype, fill_value = dtypes.maybe_promote(self.dtype) + else: + dtype = dtypes.result_type(self.dtype, fill_value) + + if sparse_format is _default: + sparse_format = "coo" + try: + as_sparse = getattr(sparse, f"as_{sparse_format.lower()}") + except AttributeError as exc: + raise ValueError(f"{sparse_format} is not a valid sparse format") from exc + + data = as_sparse(self.data.astype(dtype), fill_value=fill_value) + return self._replace(data=data) + + def _to_dense(self) -> Self: + """ + Change backend from sparse to np.array + """ + if hasattr(self._data, "todense"): + return self._replace(data=self._data.todense()) + return self.copy(deep=False) diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py new file mode 100644 index 00000000000..1495e111d85 --- /dev/null +++ b/xarray/namedarray/utils.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import importlib +import sys +import typing + +import numpy as np + +if typing.TYPE_CHECKING: + if sys.version_info >= (3, 10): + from typing import TypeGuard + else: + from typing_extensions import TypeGuard + +# temporary placeholder for indicating an array api compliant type. +# hopefully in the future we can narrow this down more +T_DuckArray = typing.TypeVar("T_DuckArray", bound=typing.Any) + + +def module_available(module: str) -> bool: + """Checks whether a module is installed without importing it. + + Use this for a lightweight check and lazy imports. + + Parameters + ---------- + module : str + Name of the module. + + Returns + ------- + available : bool + Whether the module is installed. + """ + return importlib.util.find_spec(module) is not None + + +def is_dask_collection(x: typing.Any) -> bool: + if module_available("dask"): + from dask.base import is_dask_collection + + return is_dask_collection(x) + return False + + +def is_duck_array(value: typing.Any) -> TypeGuard[T_DuckArray]: + if isinstance(value, np.ndarray): + return True + return ( + hasattr(value, "ndim") + and hasattr(value, "shape") + and hasattr(value, "dtype") + and ( + (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__")) + or hasattr(value, "__array_namespace__") + ) + ) + + +def is_duck_dask_array(x: typing.Any) -> bool: + return is_duck_array(x) and is_dask_collection(x) + + +def to_0d_object_array(value: typing.Any) -> np.ndarray: + """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object.""" + result = np.empty((), dtype=object) + result[()] = value + return result diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 3fb29e01ebb..ac641c4abc3 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -411,10 +411,14 @@ def test_repr_nep18(self) -> None: class Array: def __init__(self): self.shape = (2,) + self.ndim = 1 self.dtype = np.dtype(np.float64) def __array_function__(self, *args, **kwargs): - pass + return NotImplemented + + def __array_ufunc__(self, *args, **kwargs): + return NotImplemented def __repr__(self): return "Custom\nArray" diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 7670b77322c..5ca134503e8 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -549,7 +549,7 @@ def _repr_inline_(self, width): return formatted - def __array_function__(self, *args, **kwargs): + def __array_namespace__(self, *args, **kwargs): return NotImplemented @property diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py new file mode 100644 index 
00000000000..0871a0c6fb9 --- /dev/null +++ b/xarray/tests/test_namedarray.py @@ -0,0 +1,165 @@ +import numpy as np +import pytest + +import xarray as xr +from xarray.namedarray.core import NamedArray, as_compatible_data +from xarray.namedarray.utils import T_DuckArray + + +@pytest.fixture +def random_inputs() -> np.ndarray: + return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + + +@pytest.mark.parametrize( + "input_data, expected_output", + [ + ([1, 2, 3], np.array([1, 2, 3])), + (np.array([4, 5, 6]), np.array([4, 5, 6])), + (NamedArray("time", np.array([1, 2, 3])), np.array([1, 2, 3])), + (2, np.array(2)), + ], +) +def test_as_compatible_data( + input_data: T_DuckArray, expected_output: T_DuckArray +) -> None: + output: T_DuckArray = as_compatible_data(input_data) + assert np.array_equal(output, expected_output) + + +def test_as_compatible_data_with_masked_array() -> None: + masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) + with pytest.raises(NotImplementedError): + as_compatible_data(masked_array) + + +def test_as_compatible_data_with_0d_object() -> None: + data = np.empty((), dtype=object) + data[()] = (10, 12, 12) + np.array_equal(as_compatible_data(data), data) + + +def test_as_compatible_data_with_explicitly_indexed(random_inputs) -> None: + # TODO: Make xr.core.indexing.ExplicitlyIndexed pass is_duck_array and remove this test. + class CustomArray(xr.core.indexing.NDArrayMixin): + def __init__(self, array): + self.array = array + + class CustomArrayIndexable(CustomArray, xr.core.indexing.ExplicitlyIndexed): + pass + + array = CustomArray(random_inputs) + output = as_compatible_data(array) + assert isinstance(output, np.ndarray) + + array = CustomArrayIndexable(random_inputs) + output = as_compatible_data(array) + assert isinstance(output, CustomArrayIndexable) + + +def test_properties() -> None: + data = 0.5 * np.arange(10).reshape(2, 5) + named_array = NamedArray(["x", "y"], data, {"key": "value"}) + assert named_array.dims == ("x", "y") + assert np.array_equal(named_array.data, data) + assert named_array.attrs == {"key": "value"} + assert named_array.ndim == 2 + assert named_array.sizes == {"x": 2, "y": 5} + assert named_array.size == 10 + assert named_array.nbytes == 80 + assert len(named_array) == 2 + + +def test_attrs() -> None: + named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5)) + assert named_array.attrs == {} + named_array.attrs["key"] = "value" + assert named_array.attrs == {"key": "value"} + named_array.attrs = {"key": "value2"} + assert named_array.attrs == {"key": "value2"} + + +def test_data(random_inputs) -> None: + named_array = NamedArray(["x", "y", "z"], random_inputs) + assert np.array_equal(named_array.data, random_inputs) + with pytest.raises(ValueError): + named_array.data = np.random.random((3, 4)).astype(np.float64) + + +# Additional tests as per your original class-based code +@pytest.mark.parametrize( + "data, dtype", + [ + ("foo", np.dtype("U3")), + (np.bytes_("foo"), np.dtype("S3")), + ], +) +def test_0d_string(data, dtype: np.typing.DTypeLike) -> None: + named_array = NamedArray([], data) + assert named_array.data == data + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == dtype + + +def test_0d_object() -> None: + named_array = NamedArray([], (10, 12, 12)) + expected_data = np.empty((), dtype=object) + expected_data[()] = (10, 12, 12) + assert 
np.array_equal(named_array.data, expected_data) + + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == np.dtype("O") + + +def test_0d_datetime() -> None: + named_array = NamedArray([], np.datetime64("2000-01-01")) + assert named_array.dtype == np.dtype("datetime64[D]") + + +@pytest.mark.parametrize( + "timedelta, expected_dtype", + [ + (np.timedelta64(1, "D"), np.dtype("timedelta64[D]")), + (np.timedelta64(1, "s"), np.dtype("timedelta64[s]")), + (np.timedelta64(1, "m"), np.dtype("timedelta64[m]")), + (np.timedelta64(1, "h"), np.dtype("timedelta64[h]")), + (np.timedelta64(1, "us"), np.dtype("timedelta64[us]")), + (np.timedelta64(1, "ns"), np.dtype("timedelta64[ns]")), + (np.timedelta64(1, "ps"), np.dtype("timedelta64[ps]")), + (np.timedelta64(1, "fs"), np.dtype("timedelta64[fs]")), + (np.timedelta64(1, "as"), np.dtype("timedelta64[as]")), + ], +) +def test_0d_timedelta(timedelta, expected_dtype: np.dtype) -> None: + named_array = NamedArray([], timedelta) + assert named_array.dtype == expected_dtype + assert named_array.data == timedelta + + +@pytest.mark.parametrize( + "dims, data_shape, new_dims, raises", + [ + (["x", "y", "z"], (2, 3, 4), ["a", "b", "c"], False), + (["x", "y", "z"], (2, 3, 4), ["a", "b"], True), + (["x", "y", "z"], (2, 4, 5), ["a", "b", "c", "d"], True), + ([], [], (), False), + ([], [], ("x",), True), + ], +) +def test_dims_setter(dims, data_shape, new_dims, raises: bool) -> None: + named_array = NamedArray(dims, np.random.random(data_shape)) + assert named_array.dims == tuple(dims) + if raises: + with pytest.raises(ValueError): + named_array.dims = new_dims + else: + named_array.dims = new_dims + assert named_array.dims == tuple(new_dims) From 639ce0fd427545fedd3734fd269fb8b01804beb7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Thu, 28 Sep 2023 18:18:26 +0200 Subject: [PATCH 20/46] Bind T_DuckArray to NamedArray (#8240) * Bind T_DuckArray to NamedArray * Fix tests --- xarray/namedarray/core.py | 4 ++-- xarray/tests/test_namedarray.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 16a7b422f1b..03bfa16682d 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -68,7 +68,7 @@ def as_compatible_data( return typing.cast(T_DuckArray, np.asarray(data)) -class NamedArray: +class NamedArray(typing.Generic[T_DuckArray]): """A lightweight wrapper around duck arrays with named dimensions and attributes which describe a single Array. Numeric operations on this object implement array broadcasting and dimension alignment based on dimension names, @@ -219,7 +219,7 @@ def _check_shape(self, new_data: T_DuckArray) -> None: ) @property - def data(self): + def data(self) -> T_DuckArray: """ The NamedArray's data as an array. The underlying array type (e.g. dask, sparse, pint) is preserved. 
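With the class now generic over `T_DuckArray`, an annotation can pin the
concrete array type, and `.data` comes back as that type rather than `Any`.
A minimal sketch of the pattern, mirroring the test updates below (the
variable names here are illustrative, not from the patch):

    import numpy as np

    from xarray.namedarray.core import NamedArray

    # Subscripting fixes T_DuckArray to np.ndarray for this instance, so
    # mypy checks `.data` as np.ndarray instead of Any.
    named: NamedArray[np.ndarray] = NamedArray(("x", "y"), np.zeros((2, 5)))
    data: np.ndarray = named.data
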
diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 0871a0c6fb9..9d37a6c794c 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -59,7 +59,7 @@ class CustomArrayIndexable(CustomArray, xr.core.indexing.ExplicitlyIndexed): def test_properties() -> None: data = 0.5 * np.arange(10).reshape(2, 5) - named_array = NamedArray(["x", "y"], data, {"key": "value"}) + named_array: NamedArray[np.ndarray] = NamedArray(["x", "y"], data, {"key": "value"}) assert named_array.dims == ("x", "y") assert np.array_equal(named_array.data, data) assert named_array.attrs == {"key": "value"} @@ -71,7 +71,9 @@ def test_properties() -> None: def test_attrs() -> None: - named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5)) + named_array: NamedArray[np.ndarray] = NamedArray( + ["x", "y"], np.arange(10).reshape(2, 5) + ) assert named_array.attrs == {} named_array.attrs["key"] = "value" assert named_array.attrs == {"key": "value"} @@ -80,7 +82,7 @@ def test_attrs() -> None: def test_data(random_inputs) -> None: - named_array = NamedArray(["x", "y", "z"], random_inputs) + named_array: NamedArray[np.ndarray] = NamedArray(["x", "y", "z"], random_inputs) assert np.array_equal(named_array.data, random_inputs) with pytest.raises(ValueError): named_array.data = np.random.random((3, 4)).astype(np.float64) @@ -95,7 +97,7 @@ def test_data(random_inputs) -> None: ], ) def test_0d_string(data, dtype: np.typing.DTypeLike) -> None: - named_array = NamedArray([], data) + named_array: NamedArray[np.ndarray] = NamedArray([], data) assert named_array.data == data assert named_array.dims == () assert named_array.sizes == {} @@ -106,7 +108,7 @@ def test_0d_string(data, dtype: np.typing.DTypeLike) -> None: def test_0d_object() -> None: - named_array = NamedArray([], (10, 12, 12)) + named_array: NamedArray[np.ndarray] = NamedArray([], (10, 12, 12)) expected_data = np.empty((), dtype=object) expected_data[()] = (10, 12, 12) assert np.array_equal(named_array.data, expected_data) @@ -120,7 +122,7 @@ def test_0d_object() -> None: def test_0d_datetime() -> None: - named_array = NamedArray([], np.datetime64("2000-01-01")) + named_array: NamedArray[np.ndarray] = NamedArray([], np.datetime64("2000-01-01")) assert named_array.dtype == np.dtype("datetime64[D]") @@ -139,7 +141,7 @@ def test_0d_datetime() -> None: ], ) def test_0d_timedelta(timedelta, expected_dtype: np.dtype) -> None: - named_array = NamedArray([], timedelta) + named_array: NamedArray[np.ndarray] = NamedArray([], timedelta) assert named_array.dtype == expected_dtype assert named_array.data == timedelta @@ -155,7 +157,7 @@ def test_0d_timedelta(timedelta, expected_dtype: np.dtype) -> None: ], ) def test_dims_setter(dims, data_shape, new_dims, raises: bool) -> None: - named_array = NamedArray(dims, np.random.random(data_shape)) + named_array: NamedArray[np.ndarray] = NamedArray(dims, np.random.random(data_shape)) assert named_array.dims == tuple(dims) if raises: with pytest.raises(ValueError): From 0d6cd2a39f61128e023628c4352f653537585a12 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Thu, 28 Sep 2023 09:27:42 -0700 Subject: [PATCH 21/46] Fix & normalize typing for chunks (#8247) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix & normalize typing for chunks I noticed that `"auto"` wasn't allowed as a value in a dict. So this normalizes all chunk types, and defines the mapping as containing the inner type. 
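As a rough illustration, these are the call forms the normalized alias is
meant to accept, assuming a `Dataset` named `ds` with dimensions `x` and `y`
(illustrative, not exhaustive):

    ds.chunk(5)                      # one size for every dimension
    ds.chunk("auto")                 # defer to the chunk manager
    ds.chunk({"x": 5, "y": "auto"})  # per-dimension; "auto" now valid in a dict
    ds.chunk(x=5, y="auto")          # keyword equivalent of the mapping
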
Allows removing some ignores (though also adds one). One question — not necessary to answer now — is whether we should allow a tuple of definitions, for each dimension. Generally we use names, which helps prevent mistakes, and allows us to be less concerned about dimension ordering. --- xarray/core/dataarray.py | 11 +++-------- xarray/core/dataset.py | 22 ++++++++++++++-------- xarray/core/types.py | 11 ++++++++--- xarray/core/variable.py | 4 ++-- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0b9786dc2b7..ef4389f3c6c 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -111,6 +111,7 @@ ReindexMethodOptions, Self, SideOptions, + T_Chunks, T_Xarray, ) from xarray.core.weighted import DataArrayWeighted @@ -1288,13 +1289,7 @@ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: def chunk( self, - chunks: ( - int - | Literal["auto"] - | tuple[int, ...] - | tuple[tuple[int, ...], ...] - | Mapping[Any, None | int | tuple[int, ...]] - ) = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str | None = None, lock: bool = False, @@ -1362,7 +1357,7 @@ def chunk( if isinstance(chunks, (float, str, int)): # ignoring type; unclear why it won't accept a Literal into the value. - chunks = dict.fromkeys(self.dims, chunks) # type: ignore + chunks = dict.fromkeys(self.dims, chunks) elif isinstance(chunks, (tuple, list)): chunks = dict(zip(self.dims, chunks)) else: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index d24a62414ea..9f08c13508e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -93,7 +93,14 @@ is_duck_array, is_duck_dask_array, ) -from xarray.core.types import QuantileMethods, Self, T_DataArrayOrSet, T_Dataset +from xarray.core.types import ( + QuantileMethods, + Self, + T_ChunkDim, + T_Chunks, + T_DataArrayOrSet, + T_Dataset, +) from xarray.core.utils import ( Default, Frozen, @@ -1478,7 +1485,7 @@ def __iter__(self) -> Iterator[Hashable]: if TYPE_CHECKING: # needed because __getattr__ is returning Any and otherwise # this class counts as part of the SupportsArray Protocol - __array__ = None + __array__ = None # type: ignore[var-annotated,unused-ignore] else: @@ -2569,16 +2576,14 @@ def chunksizes(self) -> Mapping[Hashable, tuple[int, ...]]: def chunk( self, - chunks: ( - int | Literal["auto"] | Mapping[Any, None | int | str | tuple[int, ...]] - ) = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str | None = None, lock: bool = False, inline_array: bool = False, chunked_array_type: str | ChunkManagerEntrypoint | None = None, from_array_kwargs=None, - **chunks_kwargs: None | int | str | tuple[int, ...], + **chunks_kwargs: T_ChunkDim, ) -> Self: """Coerce all arrays in this dataset into dask arrays with the given chunks. 
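As a rough sketch of the call forms the widened alias is meant to admit (illustrative only, assuming dask is installed; the dataset is a placeholder):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"v": (("x", "y"), np.zeros((4, 6)))})

    ds.chunk(2)                      # one size applied to every dimension
    ds.chunk("auto")                 # let the chunk manager pick sizes
    ds.chunk({"x": 2, "y": "auto"})  # mapping form; "auto" now valid as a value
    ds.chunk(x=2, y=3)               # keyword form via **chunks_kwargs
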
@@ -2637,8 +2642,9 @@ def chunk(
             )
         chunks = {}
 
-        if isinstance(chunks, (Number, str, int)):
-            chunks = dict.fromkeys(self.dims, chunks)
+        if not isinstance(chunks, Mapping):
+            # We need to ignore since mypy doesn't recognize this can't be `None`
+            chunks = dict.fromkeys(self.dims, chunks)  # type: ignore[arg-type]
         else:
             chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
 
diff --git a/xarray/core/types.py b/xarray/core/types.py
index bbcda7ca240..795283fa88b 100644
--- a/xarray/core/types.py
+++ b/xarray/core/types.py
@@ -19,9 +19,9 @@
 
 try:
     if sys.version_info >= (3, 11):
-        from typing import Self
+        from typing import Self, TypeAlias
     else:
-        from typing_extensions import Self
+        from typing_extensions import Self, TypeAlias
 except ImportError:
     if TYPE_CHECKING:
         raise
@@ -183,7 +183,12 @@ def copy(
 Dims = Union[str, Iterable[Hashable], "ellipsis", None]
 OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None]
 
-T_Chunks = Union[int, dict[Any, Any], Literal["auto"], None]
+# FYI in some cases we don't allow `None`, which this doesn't take account of.
+T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]]
+# We allow the tuple form of this (though arguably we could transition to named dims only)
+T_Chunks: TypeAlias = Union[
+    T_ChunkDim, Mapping[Any, T_ChunkDim], tuple[T_ChunkDim, ...]
+]
 T_NormalizedChunks = tuple[tuple[int, ...], ...]
 
 DataVars = Mapping[Any, Any]
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 0e6e45d4929..4eeda073555 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -1035,7 +1035,7 @@ def chunk(
         data_old = self._data
         if chunkmanager.is_chunked_array(data_old):
-            data_chunked = chunkmanager.rechunk(data_old, chunks)  # type: ignore[arg-type]
+            data_chunked = chunkmanager.rechunk(data_old, chunks)
         else:
             if isinstance(data_old, indexing.ExplicitlyIndexed):
                 # Unambiguously handle array storage backends (like NetCDF4 and h5py)
@@ -1057,7 +1057,7 @@ def chunk(
 
             data_chunked = chunkmanager.from_array(
                 ndata,
-                chunks,  # type: ignore[arg-type]
+                chunks,
                 **_from_array_kwargs,
             )
 
From dbcf6a7245a1ba0c03b8ab490f9964f72f093185 Mon Sep 17 00:00:00 2001
From: Illviljan <14371165+Illviljan@users.noreply.github.com>
Date: Thu, 28 Sep 2023 21:39:55 +0200
Subject: [PATCH 22/46] Add type hints to maybe_promote in dtypes.py (#8243)

* Add type hints to maybe_promote

* attempt to type hint fill_value

* Update dtypes.py

* Update dtypes.py

* avoid type redefinition

* I give up with fill_value, pandas mostly do it as well.
Only 1 place had the Scalar typing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update types.py * Update dtypes.py * Update dtypes.py * Update variables.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- xarray/coding/variables.py | 2 ++ xarray/core/dtypes.py | 19 +++++++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index d694c531b15..c583afc93c2 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -304,6 +304,8 @@ def decode(self, variable: Variable, name: T_Name = None): ) # special case DateTime to properly handle NaT + dtype: np.typing.DTypeLike + decoded_fill_value: Any if _is_time_like(attrs.get("units")) and data.dtype.kind in "iu": dtype, decoded_fill_value = np.int64, np.iinfo(np.int64).min else: diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 0762fa03112..ccf84146819 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +from typing import Any import numpy as np @@ -44,7 +45,7 @@ def __eq__(self, other): ) -def maybe_promote(dtype): +def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]: """Simpler equivalent of pandas.core.common._maybe_promote Parameters @@ -57,27 +58,33 @@ def maybe_promote(dtype): fill_value : Valid missing value for the promoted dtype. """ # N.B. these casting rules should match pandas + dtype_: np.typing.DTypeLike + fill_value: Any if np.issubdtype(dtype, np.floating): + dtype_ = dtype fill_value = np.nan elif np.issubdtype(dtype, np.timedelta64): # See https://github.com/numpy/numpy/issues/10685 # np.timedelta64 is a subclass of np.integer # Check np.timedelta64 before np.integer fill_value = np.timedelta64("NaT") + dtype_ = dtype elif np.issubdtype(dtype, np.integer): - dtype = np.float32 if dtype.itemsize <= 2 else np.float64 + dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64 fill_value = np.nan elif np.issubdtype(dtype, np.complexfloating): + dtype_ = dtype fill_value = np.nan + np.nan * 1j elif np.issubdtype(dtype, np.datetime64): + dtype_ = dtype fill_value = np.datetime64("NaT") else: - dtype = object + dtype_ = object fill_value = np.nan - dtype = np.dtype(dtype) - fill_value = dtype.type(fill_value) - return dtype, fill_value + dtype_out = np.dtype(dtype_) + fill_value = dtype_out.type(fill_value) + return dtype_out, fill_value NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype} From d6c37670d1076b7e8868fdeeedf4bd9b26fd7030 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Thu, 28 Sep 2023 13:01:12 -0700 Subject: [PATCH 23/46] Refine `chunks=None` handling (#8249) * Refine `chunks=None` handling Based on comment in #8247. 
This doesn't make it perfect, but it allows the warning to get hit and
clarifies the type comment, as a stop-gap.

* Test avoiding redefinition

---------

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
---
 xarray/core/dataset.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 9f08c13508e..459e2f3fce7 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -2634,21 +2634,20 @@ def chunk(
         xarray.unify_chunks
         dask.array.from_array
         """
-        if chunks is None and chunks_kwargs is None:
+        if chunks is None and not chunks_kwargs:
            warnings.warn(
                 "None value for 'chunks' is deprecated. "
                 "It will raise an error in the future. Use instead '{}'",
                 category=FutureWarning,
             )
             chunks = {}
-
-        if not isinstance(chunks, Mapping):
-            # We need to ignore since mypy doesn't recognize this can't be `None`
-            chunks = dict.fromkeys(self.dims, chunks)  # type: ignore[arg-type]
+        chunks_mapping: Mapping[Any, Any]
+        if not isinstance(chunks, Mapping) and chunks is not None:
+            chunks_mapping = dict.fromkeys(self.dims, chunks)
         else:
-            chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
+            chunks_mapping = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk")
 
-        bad_dims = chunks.keys() - self.dims.keys()
+        bad_dims = chunks_mapping.keys() - self.dims.keys()
         if bad_dims:
             raise ValueError(
                 f"chunks keys {tuple(bad_dims)} not found in data dimensions {tuple(self.dims)}"
@@ -2662,7 +2661,7 @@ def chunk(
             k: _maybe_chunk(
                 k,
                 v,
-                chunks,
+                chunks_mapping,
                 token,
                 lock,
                 name_prefix,

From a5f666bf137a63e5f5ba64d87619978222e84354 Mon Sep 17 00:00:00 2001
From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com>
Date: Fri, 29 Sep 2023 09:39:34 -0700
Subject: [PATCH 24/46] Add modules to `check-untyped` (#8242)

* Add modules to `check-untyped`

In reviewing https://github.com/pydata/xarray/pull/8241, I realized that we
actually want `check-untyped-defs`, which is a bit less strict but lets us
add some more modules.

I did have to add a couple of ignores, but I think it's a reasonable tradeoff
for adding big modules like `computation`.

Errors with this enabled are actual type errors, not just `mypy` pedantry,
so it would be good to get as much as possible into this list...

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update xarray/core/computation.py

Co-authored-by: Michael Niklas 

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Michael Niklas 
---
 pyproject.toml             | 47 +++++++++++++++++++++++++++++++++-----
 xarray/core/computation.py | 11 +++++----
 2 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 25263928b20..e55b72341dd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -119,13 +119,48 @@ module = [
 ]
 
 # Gradually we want to add more modules to this list, ratcheting up our total
-# coverage. Once a module is here, functions require annotations in order to
-# pass mypy. It would be especially useful to have tests here, because without
-# annotating test functions, we don't have a great way of testing our type
-# annotations — even with just `-> None` is sufficient for mypy to check them.
+# coverage. Once a module is here, functions are checked by mypy regardless of
+# whether they have type annotations.
It would be especially useful to have test +# files listed here, because without them being checked, we don't have a great +# way of testing our annotations. [[tool.mypy.overrides]] -disallow_untyped_defs = true -module = ["xarray.core.rolling_exp"] +check_untyped_defs = true +module = [ + "xarray.core.accessor_dt", + "xarray.core.accessor_str", + "xarray.core.alignment", + "xarray.core.computation", + "xarray.core.rolling_exp", + "xarray.indexes.*", + "xarray.tests.*", +] +# This then excludes some modules from the above list. (So ideally we remove +# from here in time...) +[[tool.mypy.overrides]] +check_untyped_defs = false +module = [ + "xarray.tests.test_coarsen", + "xarray.tests.test_coding_times", + "xarray.tests.test_combine", + "xarray.tests.test_computation", + "xarray.tests.test_concat", + "xarray.tests.test_coordinates", + "xarray.tests.test_dask", + "xarray.tests.test_dataarray", + "xarray.tests.test_duck_array_ops", + "xarray.tests.test_groupby", + "xarray.tests.test_indexing", + "xarray.tests.test_merge", + "xarray.tests.test_missing", + "xarray.tests.test_parallelcompat", + "xarray.tests.test_plot", + "xarray.tests.test_sparse", + "xarray.tests.test_ufuncs", + "xarray.tests.test_units", + "xarray.tests.test_utils", + "xarray.tests.test_variable", + "xarray.tests.test_weighted", +] [tool.ruff] builtins = ["ellipsis"] diff --git a/xarray/core/computation.py b/xarray/core/computation.py index bae779af652..c707403db97 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -8,7 +8,7 @@ import operator import warnings from collections import Counter -from collections.abc import Hashable, Iterable, Mapping, Sequence, Set +from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence, Set from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, overload import numpy as np @@ -163,7 +163,7 @@ def to_gufunc_string(self, exclude_dims=frozenset()): if exclude_dims: exclude_dims = [self.dims_map[dim] for dim in exclude_dims] - counter = Counter() + counter: Counter = Counter() def _enumerate(dim): if dim in exclude_dims: @@ -571,7 +571,7 @@ def apply_groupby_func(func, *args): assert groupbys, "must have at least one groupby to iterate over" first_groupby = groupbys[0] (grouper,) = first_groupby.groupers - if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]): + if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]): # type: ignore[union-attr] raise ValueError( "apply_ufunc can only perform operations over " "multiple GroupBy objects at once if they are all " @@ -583,6 +583,7 @@ def apply_groupby_func(func, *args): iterators = [] for arg in args: + iterator: Iterator[Any] if isinstance(arg, GroupBy): iterator = (value for _, value in arg) elif hasattr(arg, "dims") and grouped_dim in arg.dims: @@ -597,9 +598,9 @@ def apply_groupby_func(func, *args): iterator = itertools.repeat(arg) iterators.append(iterator) - applied = (func(*zipped_args) for zipped_args in zip(*iterators)) + applied: Iterator = (func(*zipped_args) for zipped_args in zip(*iterators)) applied_example, applied = peek_at(applied) - combine = first_groupby._combine + combine = first_groupby._combine # type: ignore[attr-defined] if isinstance(applied_example, tuple): combined = tuple(combine(output) for output in zip(*applied)) else: From d8c166bd6070b532cec51d68dd1e92cd96b3db0e Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Sat, 30 Sep 2023 03:26:34 +0200 Subject: [PATCH 25/46] update pytest config and un-xfail some tests 
(#8246) * update pytest config and un-xfail some tests * requires numbagg * requires dask * add reason * Update xarray/tests/test_variable.py * Update xarray/tests/test_units.py * Apply suggestions from code review * Update xarray/tests/test_backends.py --------- Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> --- pyproject.toml | 4 ++- xarray/tests/test_backends.py | 40 ++++++++++++++-------------- xarray/tests/test_cftimeindex.py | 1 - xarray/tests/test_options.py | 1 - xarray/tests/test_rolling.py | 4 +-- xarray/tests/test_sparse.py | 13 ++++++--- xarray/tests/test_units.py | 24 ++++++++--------- xarray/tests/test_variable.py | 45 ++++++++++++++++++++------------ 8 files changed, 74 insertions(+), 58 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e55b72341dd..294b71ad671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -190,7 +190,9 @@ select = [ known-first-party = ["xarray"] [tool.pytest.ini_options] -addopts = '--strict-markers' +addopts = ["--strict-config", "--strict-markers"] +log_cli_level = "INFO" +minversion = "7" filterwarnings = [ "ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning", ] diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 5bd517098f1..0cbf3af3664 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -714,9 +714,6 @@ def multiple_indexing(indexers): ] multiple_indexing(indexers5) - @pytest.mark.xfail( - reason="zarr without dask handles negative steps in slices incorrectly", - ) def test_vectorized_indexing_negative_step(self) -> None: # use dask explicitly when present open_kwargs: dict[str, Any] | None @@ -1842,8 +1839,8 @@ def test_unsorted_index_raises(self) -> None: # dask first pulls items by block. 
pass + @pytest.mark.skip(reason="caching behavior differs for dask") def test_dataset_caching(self) -> None: - # caching behavior differs for dask pass def test_write_inconsistent_chunks(self) -> None: @@ -2261,9 +2258,6 @@ def test_encoding_kwarg_fixed_width_string(self) -> None: # not relevant for zarr, since we don't use EncodedStringCoder pass - # TODO: someone who understand caching figure out whether caching - # makes sense for Zarr backend - @pytest.mark.xfail(reason="Zarr caching not implemented") def test_dataset_caching(self) -> None: super().test_dataset_caching() @@ -2712,6 +2706,14 @@ def test_attributes(self, obj) -> None: with pytest.raises(TypeError, match=r"Invalid attribute in Dataset.attrs."): ds.to_zarr(store_target, **self.version_kwargs) + def test_vectorized_indexing_negative_step(self) -> None: + if not has_dask: + pytest.xfail( + reason="zarr without dask handles negative steps in slices incorrectly" + ) + + super().test_vectorized_indexing_negative_step() + @requires_zarr class TestZarrDictStore(ZarrBase): @@ -3378,8 +3380,8 @@ def roundtrip( ) as ds: yield ds + @pytest.mark.skip(reason="caching behavior differs for dask") def test_dataset_caching(self) -> None: - # caching behavior differs for dask pass def test_write_inconsistent_chunks(self) -> None: @@ -3982,7 +3984,6 @@ def test_open_mfdataset_raise_on_bad_combine_args(self) -> None: with pytest.raises(ValueError, match="`concat_dim` has no effect"): open_mfdataset([tmp1, tmp2], concat_dim="x") - @pytest.mark.xfail(reason="mfdataset loses encoding currently.") def test_encoding_mfdataset(self) -> None: original = Dataset( { @@ -4195,7 +4196,6 @@ def test_dataarray_compute(self) -> None: assert computed._in_memory assert_allclose(actual, computed, decode_bytes=False) - @pytest.mark.xfail def test_save_mfdataset_compute_false_roundtrip(self) -> None: from dask.delayed import Delayed @@ -5125,15 +5125,17 @@ def test_open_fsspec() -> None: ds2 = open_dataset(url, engine="zarr") xr.testing.assert_equal(ds0, ds2) - # multi dataset - url = "memory://out*.zarr" - ds2 = open_mfdataset(url, engine="zarr") - xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) - - # multi dataset with caching - url = "simplecache::memory://out*.zarr" - ds2 = open_mfdataset(url, engine="zarr") - xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) + # open_mfdataset requires dask + if has_dask: + # multi dataset + url = "memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) + + # multi dataset with caching + url = "simplecache::memory://out*.zarr" + ds2 = open_mfdataset(url, engine="zarr") + xr.testing.assert_equal(xr.concat([ds, ds0], dim="time"), ds2) @requires_h5netcdf diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index f58a6490632..1a1df6b81fe 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -1135,7 +1135,6 @@ def test_to_datetimeindex_feb_29(calendar): @requires_cftime -@pytest.mark.xfail(reason="https://github.com/pandas-dev/pandas/issues/24263") def test_multiindex(): index = xr.cftime_range("2001-01-01", periods=100, calendar="360_day") mindex = pd.MultiIndex.from_arrays([index]) diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index 3cecf1b52ec..8ad1cbe11be 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -165,7 +165,6 @@ def test_concat_attr_retention(self) -> None: result = concat([ds1, ds2], dim="dim1") 
assert result.attrs == original_attrs - @pytest.mark.xfail def test_merge_attr_retention(self) -> None: da1 = create_test_dataarray_attrs(var="var1") da2 = create_test_dataarray_attrs(var="var2") diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 72d1b9071dd..2dc8ae24438 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -766,9 +766,7 @@ def test_ndrolling_construct(self, center, fill_value, dask) -> None: ) assert_allclose(actual, expected) - @pytest.mark.xfail( - reason="See https://github.com/pydata/xarray/pull/4369 or docstring" - ) + @requires_dask @pytest.mark.filterwarnings("error") @pytest.mark.parametrize("ds", (2,), indirect=True) @pytest.mark.parametrize("name", ("mean", "max")) diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py index f64ce9338d7..489836b70fd 100644 --- a/xarray/tests/test_sparse.py +++ b/xarray/tests/test_sparse.py @@ -147,7 +147,6 @@ def test_variable_property(prop): ], ), True, - marks=xfail(reason="Coercion to dense"), ), param( do("conjugate"), @@ -201,7 +200,6 @@ def test_variable_property(prop): param( do("reduce", func="sum", dim="x"), True, - marks=xfail(reason="Coercion to dense"), ), param( do("rolling_window", dim="x", window=2, window_dim="x_win"), @@ -218,7 +216,7 @@ def test_variable_property(prop): param( do("var"), False, marks=xfail(reason="Missing implementation for np.nanvar") ), - param(do("to_dict"), False, marks=xfail(reason="Coercion to dense")), + param(do("to_dict"), False), (do("where", cond=make_xrvar({"x": 10, "y": 5}) > 0.5), True), ], ids=repr, @@ -237,7 +235,14 @@ def test_variable_method(func, sparse_output): assert isinstance(ret_s.data, sparse.SparseArray) assert np.allclose(ret_s.data.todense(), ret_d.data, equal_nan=True) else: - assert np.allclose(ret_s, ret_d, equal_nan=True) + if func.meth != "to_dict": + assert np.allclose(ret_s, ret_d) + else: + # pop the arrays from the dict + arr_s, arr_d = ret_s.pop("data"), ret_d.pop("data") + + assert np.allclose(arr_s, arr_d) + assert ret_s == ret_d @pytest.mark.parametrize( diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index addd7587544..d89a74e4fba 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -18,6 +18,7 @@ assert_identical, requires_dask, requires_matplotlib, + requires_numbagg, ) from xarray.tests.test_plot import PlotTestCase from xarray.tests.test_variable import _PAD_XR_NP_ARGS @@ -2548,7 +2549,6 @@ def test_univariate_ufunc(self, units, error, dtype): assert_units_equal(expected, actual) assert_identical(expected, actual) - @pytest.mark.xfail(reason="needs the type register system for __array_ufunc__") @pytest.mark.parametrize( "unit,error", ( @@ -3849,23 +3849,21 @@ def test_computation(self, func, variant, dtype): method("groupby", "x"), method("groupby_bins", "y", bins=4), method("coarsen", y=2), - pytest.param( - method("rolling", y=3), - marks=pytest.mark.xfail( - reason="numpy.lib.stride_tricks.as_strided converts to ndarray" - ), - ), - pytest.param( - method("rolling_exp", y=3), - marks=pytest.mark.xfail( - reason="numbagg functions are not supported by pint" - ), - ), + method("rolling", y=3), + pytest.param(method("rolling_exp", y=3), marks=requires_numbagg), method("weighted", xr.DataArray(data=np.linspace(0, 1, 10), dims="y")), ), ids=repr, ) def test_computation_objects(self, func, variant, dtype): + if variant == "data": + if func.name == "rolling_exp": + pytest.xfail(reason="numbagg functions are not supported by pint") + elif 
func.name == "rolling": + pytest.xfail( + reason="numpy.lib.stride_tricks.as_strided converts to ndarray" + ) + unit = unit_registry.m variants = { diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 4fcd5f98d8f..f162b1c7d0a 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -885,20 +885,10 @@ def test_getitem_error(self): "mode", [ "mean", - pytest.param( - "median", - marks=pytest.mark.xfail(reason="median is not implemented by Dask"), - ), - pytest.param( - "reflect", marks=pytest.mark.xfail(reason="dask.array.pad bug") - ), + "median", + "reflect", "edge", - pytest.param( - "linear_ramp", - marks=pytest.mark.xfail( - reason="pint bug: https://github.com/hgrecco/pint/issues/1026" - ), - ), + "linear_ramp", "maximum", "minimum", "symmetric", @@ -2345,12 +2335,35 @@ def test_dask_rolling(self, dim, window, center): assert actual.shape == expected.shape assert_equal(actual, expected) - @pytest.mark.xfail( - reason="https://github.com/pydata/xarray/issues/6209#issuecomment-1025116203" - ) def test_multiindex(self): super().test_multiindex() + @pytest.mark.parametrize( + "mode", + [ + "mean", + pytest.param( + "median", + marks=pytest.mark.xfail(reason="median is not implemented by Dask"), + ), + pytest.param( + "reflect", marks=pytest.mark.xfail(reason="dask.array.pad bug") + ), + "edge", + "linear_ramp", + "maximum", + "minimum", + "symmetric", + "wrap", + ], + ) + @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS) + @pytest.mark.filterwarnings( + r"ignore:dask.array.pad.+? converts integers to floats." + ) + def test_pad(self, mode, xr_arg, np_arg): + super().test_pad(mode, xr_arg, np_arg) + @requires_sparse class TestVariableWithSparse: From f8ab40c5fc1424f9c66206ba9f00dc21735890af Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 30 Sep 2023 11:50:33 -0700 Subject: [PATCH 26/46] Accept `lambda` for `other` param (#8256) * Accept `lambda` for `other` param --- doc/whats-new.rst | 4 ++++ xarray/core/common.py | 28 ++++++++++++++++------------ xarray/tests/test_dataarray.py | 8 ++++++++ 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 17744288aef..e485b24bf3e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,10 @@ v2023.09.1 (unreleased) New Features ~~~~~~~~~~~~ +- :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for + the ``other`` parameter, passing the object as the first argument. Previously, + this was only valid for the ``cond`` parameter. (:issue:`8255`) + By `Maximilian Roos `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/common.py b/xarray/core/common.py index db9b2aead23..2a4c4c200d4 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1074,9 +1074,10 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: cond : DataArray, Dataset, or callable Locations at which to preserve this object's values. dtype must be `bool`. If a callable, it must expect this object as its only parameter. - other : scalar, DataArray or Dataset, optional + other : scalar, DataArray, Dataset, or callable, optional Value to use for locations in this object where ``cond`` is False. - By default, these locations filled with NA. + By default, these locations are filled with NA. If a callable, it must + expect this object as its only parameter. 
        drop : bool, default: False
            If True, coordinate labels that only correspond to False values of
            the condition are dropped from the result.
@@ -1124,7 +1125,16 @@ def where(
                [15., nan, nan, nan]])
         Dimensions without coordinates: x, y
 
-        >>> a.where(lambda x: x.x + x.y < 4, drop=True)
+        >>> a.where(lambda x: x.x + x.y < 4, lambda x: -x)
+        <xarray.DataArray (x: 5, y: 5)>
+        array([[  0,   1,   2,   3,  -4],
+               [  5,   6,   7,  -8,  -9],
+               [ 10,  11, -12, -13, -14],
+               [ 15, -16, -17, -18, -19],
+               [-20, -21, -22, -23, -24]])
+        Dimensions without coordinates: x, y
+
+        >>> a.where(a.x + a.y < 4, drop=True)
         <xarray.DataArray (x: 4, y: 4)>
         array([[ 0.,  1.,  2.,  3.],
                [ 5.,  6.,  7., nan],
                [10., 11., nan, nan],
                [15., nan, nan, nan]])
         Dimensions without coordinates: x, y
 
-        >>> a.where(a.x + a.y < 4, -1, drop=True)
-        <xarray.DataArray (x: 4, y: 4)>
-        array([[ 0,  1,  2,  3],
-               [ 5,  6,  7, -1],
-               [10, 11, -1, -1],
-               [15, -1, -1, -1]])
-        Dimensions without coordinates: x, y
-
         See Also
         --------
         numpy.where : corresponding numpy function
@@ -1151,11 +1153,13 @@ def where(
 
         if callable(cond):
             cond = cond(self)
+        if callable(other):
+            other = other(self)
 
         if drop:
             if not isinstance(cond, (Dataset, DataArray)):
                 raise TypeError(
-                    f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r}"
+                    f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r} (or a callable that returns one)."
                 )
 
             self, cond = align(self, cond)  # type: ignore[assignment]
diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py
index 11ebc4da347..63175f2be40 100644
--- a/xarray/tests/test_dataarray.py
+++ b/xarray/tests/test_dataarray.py
@@ -2717,6 +2717,14 @@ def test_where_lambda(self) -> None:
         actual = arr.where(lambda x: x.y < 2, drop=True)
         assert_identical(actual, expected)
 
+    def test_where_other_lambda(self) -> None:
+        arr = DataArray(np.arange(4), dims="y")
+        expected = xr.concat(
+            [arr.sel(y=slice(2)), arr.sel(y=slice(2, None)) + 1], dim="y"
+        )
+        actual = arr.where(lambda x: x.y < 2, lambda x: x + 1)
+        assert_identical(actual, expected)
+
     def test_where_string(self) -> None:
         array = DataArray(["a", "b"])
         expected = DataArray(np.array(["a", np.nan], dtype=object))

From 26b5fe2a3defbd88793d38aff5c45abf5d1e2163 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 2 Oct 2023 22:52:42 +0200
Subject: [PATCH 27/46] [pre-commit.ci] pre-commit autoupdate (#8262)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/astral-sh/ruff-pre-commit: v0.0.287 → v0.0.292](https://github.com/astral-sh/ruff-pre-commit/compare/v0.0.287...v0.0.292)
- [github.com/psf/black: 23.7.0 → 23.9.1](https://github.com/psf/black/compare/23.7.0...23.9.1)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .pre-commit-config.yaml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c2586a12aa2..5626f450ec0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -18,13 +18,13 @@ repos:
     files: ^xarray/
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: 'v0.0.287'
+    rev: 'v0.0.292'
    hooks:
      - id: ruff
        args: ["--fix"]
  # https://github.com/python/black#version-control-integration
  - repo: https://github.com/psf/black
-    rev: 23.7.0
+    rev: 23.9.1
    hooks:
      - id: black-jupyter
  - repo: https://github.com/keewis/blackdoc
@@ -32,7 +32,7 @@ repos:
    hooks:
      - id: blackdoc
        exclude: "generate_aggregations.py"
-        additional_dependencies: ["black==23.7.0"]
+        additional_dependencies: ["black==23.9.1"]
      - id: blackdoc-autoupdate-black
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.5.1
From a1d168d5e706446c68f79ae9923d2d2c34778c14 Mon Sep 17 00:00:00 2001
From: Pieter Eendebak 
Date: Tue, 3 Oct 2023 00:09:53 +0200
Subject: [PATCH 28/46] Update type annotation for center argument of
 dataarray_plot methods (#8261)

* Update type annotation for center argument of dataarray_plot methods

* address review comments

---
 doc/whats-new.rst             |  2 ++
 xarray/plot/dataarray_plot.py | 34 +++++++++++++++++-----------------
 2 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index e485b24bf3e..63c9dee04c5 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -139,6 +139,8 @@ Bug fixes
 - Static typing of dunder ops methods (like :py:meth:`DataArray.__eq__`) has been fixed.
   Remaining issues are upstream problems (:issue:`7780`, :pull:`8204`).
   By `Michael Niklas `_.
+- Fix type annotation for ``center`` argument of plotting methods (like :py:meth:`xarray.plot.dataarray_plot.pcolormesh`) (:pull:`8261`).
+  By `Pieter Eendebak `_.
 
 Documentation
 ~~~~~~~~~~~~~
diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py
index 8e930d0731c..61f2014fbc3 100644
--- a/xarray/plot/dataarray_plot.py
+++ b/xarray/plot/dataarray_plot.py
@@ -1348,7 +1348,7 @@ def _plot2d(plotfunc):
         `seaborn color palette `_.
         Note: if ``cmap`` is a seaborn color palette and the plot type
         is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified.
-    center : float, optional
+    center : float or False, optional
         The value at which to center the colormap. Passing this value implies
         use of a diverging colormap. Setting it to ``False`` prevents use of a
         diverging colormap.
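A short sketch of the three states the widened annotation now distinguishes (illustrative only, assuming matplotlib is installed; the array is a placeholder):

    import numpy as np
    import xarray as xr

    da = xr.DataArray(np.random.randn(4, 5), dims=("y", "x"))

    da.plot.pcolormesh(center=0.0)    # diverging colormap, centered on 0.0
    da.plot.pcolormesh(center=False)  # never infer a diverging colormap
    da.plot.pcolormesh()              # center=None: let xarray decide
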
@@ -1432,7 +1432,7 @@ def newplotfunc( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1692,7 +1692,7 @@ def imshow( # type: ignore[misc,unused-ignore] # None is hashable :( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1733,7 +1733,7 @@ def imshow( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1774,7 +1774,7 @@ def imshow( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1911,7 +1911,7 @@ def contour( # type: ignore[misc,unused-ignore] # None is hashable :( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1952,7 +1952,7 @@ def contour( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -1993,7 +1993,7 @@ def contour( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2047,7 +2047,7 @@ def contourf( # type: ignore[misc,unused-ignore] # None is hashable :( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2088,7 +2088,7 @@ def contourf( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2129,7 +2129,7 @@ def contourf( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2183,7 +2183,7 @@ def pcolormesh( # type: ignore[misc,unused-ignore] # None is hashable :( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2224,7 +2224,7 @@ def pcolormesh( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: 
float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2265,7 +2265,7 @@ def pcolormesh( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2370,7 +2370,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2411,7 +2411,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, @@ -2452,7 +2452,7 @@ def surface( vmin: float | None = None, vmax: float | None = None, cmap: str | Colormap | None = None, - center: float | None = None, + center: float | Literal[False] | None = None, robust: bool = False, extend: ExtendOptions = None, levels: ArrayLike | None = None, From d5f17858e5739c986bfb52e7f2ad106bb4489364 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 3 Oct 2023 19:18:41 +0200 Subject: [PATCH 29/46] Use strict type hinting for namedarray (#8241) * Disallow untyped defs in namedarray * Just use strict instead * Update pyproject.toml * Test explicit list instead. * Update pyproject.toml * Update pyproject.toml * Update pyproject.toml * Update utils.py * Update core.py * getmaskarray isn't typed yet * Update core.py * add _Array protocol * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update utils.py * Update test_namedarray.py * Update utils.py * Update test_namedarray.py * Update test_namedarray.py * Update utils.py * Update utils.py * Update utils.py * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py * Update core.py * Update test_namedarray.py * Update test_namedarray.py * Update test_namedarray.py * Update utils.py * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * Update core.py * Update utils.py * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * Update test_namedarray.py * Update utils.py * Update pyproject.toml * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * Update utils.py * Update xarray/namedarray/utils.py Co-authored-by: Michael Niklas * Update utils.py * Update core.py * Update utils.py * Update core.py * Update utils.py * Update core.py * Update core.py * Update core.py * Update test_namedarray.py * Update utils.py * Update core.py * Update utils.py * Update test_namedarray.py * Update test_namedarray.py * Update 
core.py * Update parallel.py * Update utils.py * fixes * Update utils.py * Update utils.py * ignores * Update xarray/namedarray/utils.py Co-authored-by: Michael Niklas * Update xarray/namedarray/utils.py Co-authored-by: Michael Niklas * Update core.py * Update test_namedarray.py * Update core.py * Update core.py * Update core.py * Update core.py * Update test_namedarray.py * Update core.py * Update test_namedarray.py * import specific type functions * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py * Update core.py * Update core.py * Try chunkedarray instead * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes * Update core.py * Update core.py * Update core.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update core.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michael Niklas Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> --- pyproject.toml | 36 ++++++ xarray/core/parallel.py | 2 +- xarray/namedarray/core.py | 188 +++++++++++++++++++++----------- xarray/namedarray/utils.py | 88 +++++++++++++-- xarray/tests/test_namedarray.py | 88 +++++++++++---- 5 files changed, 304 insertions(+), 98 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 294b71ad671..e24f88d9679 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ module = [ "cftime.*", "cubed.*", "cupy.*", + "dask.types.*", "fsspec.*", "h5netcdf.*", "h5py.*", @@ -162,6 +163,41 @@ module = [ "xarray.tests.test_weighted", ] +# Use strict = true whenever namedarray has become standalone. 
In the meantime +# don't forget to add all new files related to namedarray here: +# ref: https://mypy.readthedocs.io/en/stable/existing_code.html#introduce-stricter-options +[[tool.mypy.overrides]] +# Start off with these +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true + +# Getting these passing should be easy +strict_equality = true +strict_concatenate = true + +# Strongly recommend enabling this one as soon as you can +check_untyped_defs = true + +# These shouldn't be too much additional work, but may be tricky to +# get passing if you use a lot of untyped libraries +disallow_subclassing_any = true +disallow_untyped_decorators = true +disallow_any_generics = true + +# These next few are various gradations of forcing use of type annotations +disallow_untyped_calls = true +disallow_incomplete_defs = true +disallow_untyped_defs = true + +# This one isn't too hard to get passing, but return on investment is lower +no_implicit_reexport = true + +# This one can be tricky to get passing if you use a lot of untyped libraries +warn_return_any = true + +module = ["xarray.namedarray.*", "xarray.tests.test_namedarray"] + [tool.ruff] builtins = ["ellipsis"] exclude = [ diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 07c3c606bf2..949576b4ee8 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -443,7 +443,7 @@ def subset_dataset_to_block( for dim in variable.dims: chunk = chunk[chunk_index[dim]] - chunk_variable_task = (f"{name}-{gname}-{chunk[0]}",) + chunk_tuple + chunk_variable_task = (f"{name}-{gname}-{chunk[0]!r}",) + chunk_tuple graph[chunk_variable_task] = ( tuple, [variable.dims, chunk, variable.attrs], diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 03bfa16682d..9b7aff9d067 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -2,39 +2,46 @@ import copy import math -import sys -import typing -from collections.abc import Hashable, Iterable, Mapping +from collections.abc import Hashable, Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast import numpy as np # TODO: get rid of this after migrating this class to array API from xarray.core import dtypes from xarray.core.indexing import ExplicitlyIndexed -from xarray.core.utils import Default, _default from xarray.namedarray.utils import ( + Default, T_DuckArray, + _default, + is_chunked_duck_array, is_duck_array, is_duck_dask_array, to_0d_object_array, ) -if typing.TYPE_CHECKING: - T_NamedArray = typing.TypeVar("T_NamedArray", bound="NamedArray") - DimsInput = typing.Union[str, Iterable[Hashable]] - Dims = tuple[Hashable, ...] - +if TYPE_CHECKING: + from xarray.namedarray.utils import Self # type: ignore[attr-defined] -try: - if sys.version_info >= (3, 11): - from typing import Self - else: - from typing_extensions import Self -except ImportError: - if typing.TYPE_CHECKING: - raise - else: - Self: typing.Any = None + try: + from dask.typing import ( + Graph, + NestedKeys, + PostComputeCallable, + PostPersistCallable, + SchedulerGetCallable, + ) + except ImportError: + Graph: Any # type: ignore[no-redef] + NestedKeys: Any # type: ignore[no-redef] + SchedulerGetCallable: Any # type: ignore[no-redef] + PostComputeCallable: Any # type: ignore[no-redef] + PostPersistCallable: Any # type: ignore[no-redef] + + # T_NamedArray = TypeVar("T_NamedArray", bound="NamedArray[T_DuckArray]") + DimsInput = Union[str, Iterable[Hashable]] + Dims = tuple[Hashable, ...] 
+ AttrsInput = Union[Mapping[Any, Any], None] # TODO: Add tests! @@ -43,44 +50,48 @@ def as_compatible_data( ) -> T_DuckArray: if fastpath and getattr(data, "ndim", 0) > 0: # can't use fastpath (yet) for scalars - return typing.cast(T_DuckArray, data) + return cast(T_DuckArray, data) if isinstance(data, np.ma.MaskedArray): - mask = np.ma.getmaskarray(data) + mask = np.ma.getmaskarray(data) # type: ignore[no-untyped-call] if mask.any(): # TODO: requires refactoring/vendoring xarray.core.dtypes and xarray.core.duck_array_ops raise NotImplementedError("MaskedArray is not supported yet") else: - return typing.cast(T_DuckArray, np.asarray(data)) + return cast(T_DuckArray, np.asarray(data)) if is_duck_array(data): return data if isinstance(data, NamedArray): - return typing.cast(T_DuckArray, data.data) + return cast(T_DuckArray, data.data) if isinstance(data, ExplicitlyIndexed): # TODO: better that is_duck_array(ExplicitlyIndexed) -> True - return typing.cast(T_DuckArray, data) + return cast(T_DuckArray, data) if isinstance(data, tuple): data = to_0d_object_array(data) # validate whether the data is valid data types. - return typing.cast(T_DuckArray, np.asarray(data)) + return cast(T_DuckArray, np.asarray(data)) -class NamedArray(typing.Generic[T_DuckArray]): +class NamedArray(Generic[T_DuckArray]): """A lightweight wrapper around duck arrays with named dimensions and attributes which describe a single Array. Numeric operations on this object implement array broadcasting and dimension alignment based on dimension names, rather than axis order.""" - __slots__ = ("_dims", "_data", "_attrs") + __slots__ = ("_data", "_dims", "_attrs") + + _data: T_DuckArray + _dims: Dims + _attrs: dict[Any, Any] | None def __init__( self, dims: DimsInput, data: T_DuckArray | np.typing.ArrayLike, - attrs: dict | None = None, + attrs: AttrsInput = None, fastpath: bool = False, ): """ @@ -105,9 +116,9 @@ def __init__( """ - self._data: T_DuckArray = as_compatible_data(data, fastpath=fastpath) - self._dims: Dims = self._parse_dimensions(dims) - self._attrs: dict | None = dict(attrs) if attrs else None + self._data = as_compatible_data(data, fastpath=fastpath) + self._dims = self._parse_dimensions(dims) + self._attrs = dict(attrs) if attrs else None @property def ndim(self) -> int: @@ -140,7 +151,7 @@ def __len__(self) -> int: raise TypeError("len() of unsized object") from exc @property - def dtype(self) -> np.dtype: + def dtype(self) -> np.dtype[Any]: """ Data-type of the array’s elements. @@ -178,7 +189,7 @@ def nbytes(self) -> int: the bytes consumed based on the ``size`` and ``dtype``. 
""" if hasattr(self._data, "nbytes"): - return self._data.nbytes + return self._data.nbytes # type: ignore[no-any-return] else: return self.size * self.dtype.itemsize @@ -201,14 +212,14 @@ def _parse_dimensions(self, dims: DimsInput) -> Dims: return dims @property - def attrs(self) -> dict[typing.Any, typing.Any]: + def attrs(self) -> dict[Any, Any]: """Dictionary of local attributes on this NamedArray.""" if self._attrs is None: self._attrs = {} return self._attrs @attrs.setter - def attrs(self, value: Mapping) -> None: + def attrs(self, value: Mapping[Any, Any]) -> None: self._attrs = dict(value) def _check_shape(self, new_data: T_DuckArray) -> None: @@ -256,43 +267,84 @@ def imag(self) -> Self: """ return self._replace(data=self.data.imag) - def __dask_tokenize__(self): + def __dask_tokenize__(self) -> Hashable | None: # Use v.data, instead of v._data, in order to cope with the wrappers # around NetCDF and the like from dask.base import normalize_token - return normalize_token((type(self), self._dims, self.data, self.attrs)) + s, d, a, attrs = type(self), self._dims, self.data, self.attrs + return normalize_token((s, d, a, attrs)) # type: ignore[no-any-return] - def __dask_graph__(self): - return self._data.__dask_graph__() if is_duck_dask_array(self._data) else None + def __dask_graph__(self) -> Graph | None: + if is_duck_dask_array(self._data): + return self._data.__dask_graph__() + else: + # TODO: Should this method just raise instead? + # raise NotImplementedError("Method requires self.data to be a dask array") + return None - def __dask_keys__(self): - return self._data.__dask_keys__() + def __dask_keys__(self) -> NestedKeys: + if is_duck_dask_array(self._data): + return self._data.__dask_keys__() + else: + raise AttributeError("Method requires self.data to be a dask array.") - def __dask_layers__(self): - return self._data.__dask_layers__() + def __dask_layers__(self) -> Sequence[str]: + if is_duck_dask_array(self._data): + return self._data.__dask_layers__() + else: + raise AttributeError("Method requires self.data to be a dask array.") @property - def __dask_optimize__(self) -> typing.Callable: - return self._data.__dask_optimize__ + def __dask_optimize__( + self, + ) -> Callable[..., dict[Any, Any]]: + if is_duck_dask_array(self._data): + return self._data.__dask_optimize__ # type: ignore[no-any-return] + else: + raise AttributeError("Method requires self.data to be a dask array.") @property - def __dask_scheduler__(self) -> typing.Callable: - return self._data.__dask_scheduler__ + def __dask_scheduler__(self) -> SchedulerGetCallable: + if is_duck_dask_array(self._data): + return self._data.__dask_scheduler__ + else: + raise AttributeError("Method requires self.data to be a dask array.") def __dask_postcompute__( self, - ) -> tuple[typing.Callable, tuple[typing.Any, ...]]: - array_func, array_args = self._data.__dask_postcompute__() - return self._dask_finalize, (array_func,) + array_args + ) -> tuple[PostComputeCallable, tuple[Any, ...]]: + if is_duck_dask_array(self._data): + array_func, array_args = self._data.__dask_postcompute__() # type: ignore[no-untyped-call] + return self._dask_finalize, (array_func,) + array_args + else: + raise AttributeError("Method requires self.data to be a dask array.") def __dask_postpersist__( self, - ) -> tuple[typing.Callable, tuple[typing.Any, ...]]: - array_func, array_args = self._data.__dask_postpersist__() - return self._dask_finalize, (array_func,) + array_args + ) -> tuple[ + Callable[ + [Graph, PostPersistCallable[Any], Any, Any], + 
Self, + ], + tuple[Any, ...], + ]: + if is_duck_dask_array(self._data): + a: tuple[PostPersistCallable[Any], tuple[Any, ...]] + a = self._data.__dask_postpersist__() # type: ignore[no-untyped-call] + array_func, array_args = a + + return self._dask_finalize, (array_func,) + array_args + else: + raise AttributeError("Method requires self.data to be a dask array.") - def _dask_finalize(self, results, array_func, *args, **kwargs) -> Self: + def _dask_finalize( + self, + results: Graph, + array_func: PostPersistCallable[Any], + *args: Any, + **kwargs: Any, + ) -> Self: data = array_func(results, *args, **kwargs) return type(self)(self._dims, data, attrs=self._attrs) @@ -308,12 +360,16 @@ def chunks(self) -> tuple[tuple[int, ...], ...] | None: NamedArray.chunksizes xarray.unify_chunks """ - return getattr(self._data, "chunks", None) + data = self._data + if is_chunked_duck_array(data): + return data.chunks + else: + return None @property def chunksizes( self, - ) -> typing.Mapping[typing.Any, tuple[int, ...]]: + ) -> Mapping[Any, tuple[int, ...]]: """ Mapping from dimension names to block lengths for this namedArray's data, or None if the underlying data is not a dask array. @@ -328,8 +384,9 @@ def chunksizes( NamedArray.chunks xarray.unify_chunks """ - if hasattr(self._data, "chunks"): - return dict(zip(self.dims, self.data.chunks)) + data = self._data + if is_chunked_duck_array(data): + return dict(zip(self.dims, data.chunks)) else: return {} @@ -338,7 +395,12 @@ def sizes(self) -> dict[Hashable, int]: """Ordered mapping from dimension names to lengths.""" return dict(zip(self.dims, self.shape)) - def _replace(self, dims=_default, data=_default, attrs=_default) -> Self: + def _replace( + self, + dims: DimsInput | Default = _default, + data: T_DuckArray | np.typing.ArrayLike | Default = _default, + attrs: AttrsInput | Default = _default, + ) -> Self: if dims is _default: dims = copy.copy(self._dims) if data is _default: @@ -351,7 +413,7 @@ def _copy( self, deep: bool = True, data: T_DuckArray | np.typing.ArrayLike | None = None, - memo: dict[int, typing.Any] | None = None, + memo: dict[int, Any] | None = None, ) -> Self: if data is None: ndata = self._data @@ -370,7 +432,7 @@ def _copy( def __copy__(self) -> Self: return self._copy(deep=False) - def __deepcopy__(self, memo: dict[int, typing.Any] | None = None) -> Self: + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: return self._copy(deep=True, memo=memo) def copy( @@ -415,7 +477,7 @@ def _nonzero(self) -> tuple[Self, ...]: def _as_sparse( self, sparse_format: str | Default = _default, - fill_value=dtypes.NA, + fill_value: np.typing.ArrayLike | Default = _default, ) -> Self: """ use sparse-array as backend. @@ -423,7 +485,7 @@ def _as_sparse( import sparse # TODO: what to do if dask-backended? 
- if fill_value is dtypes.NA: + if fill_value is _default: dtype, fill_value = dtypes.maybe_promote(self.dtype) else: dtype = dtypes.result_type(self.dtype, fill_value) diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py index 1495e111d85..c77009aeb2d 100644 --- a/xarray/namedarray/utils.py +++ b/xarray/namedarray/utils.py @@ -2,19 +2,79 @@ import importlib import sys -import typing +from enum import Enum +from typing import TYPE_CHECKING, Any, Final, Protocol, TypeVar import numpy as np -if typing.TYPE_CHECKING: +if TYPE_CHECKING: if sys.version_info >= (3, 10): from typing import TypeGuard else: from typing_extensions import TypeGuard + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + + try: + from dask.array import Array as DaskArray + from dask.types import DaskCollection + except ImportError: + DaskArray = np.ndarray # type: ignore + DaskCollection: Any = np.ndarray # type: ignore + + +# https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array +T_DType_co = TypeVar("T_DType_co", bound=np.dtype[np.generic], covariant=True) +# T_DType = TypeVar("T_DType", bound=np.dtype[np.generic]) + + +class _Array(Protocol[T_DType_co]): + @property + def dtype(self) -> T_DType_co: + ... + + @property + def shape(self) -> tuple[int, ...]: + ... + + @property + def real(self) -> Self: + ... + + @property + def imag(self) -> Self: + ... + + def astype(self, dtype: np.typing.DTypeLike) -> Self: + ... + + # TODO: numpy doesn't use any inputs: + # https://github.com/numpy/numpy/blob/v1.24.3/numpy/_typing/_array_like.py#L38 + def __array__(self) -> np.ndarray[Any, T_DType_co]: + ... + + +class _ChunkedArray(_Array[T_DType_co], Protocol[T_DType_co]): + @property + def chunks(self) -> tuple[tuple[int, ...], ...]: + ... + + # temporary placeholder for indicating an array api compliant type. 
# hopefully in the future we can narrow this down more -T_DuckArray = typing.TypeVar("T_DuckArray", bound=typing.Any) +T_DuckArray = TypeVar("T_DuckArray", bound=_Array[np.dtype[np.generic]]) +T_ChunkedArray = TypeVar("T_ChunkedArray", bound=_ChunkedArray[np.dtype[np.generic]]) + + +# Singleton type, as per https://github.com/python/typing/pull/240 +class Default(Enum): + token: Final = 0 + + +_default = Default.token def module_available(module: str) -> bool: @@ -35,15 +95,15 @@ def module_available(module: str) -> bool: return importlib.util.find_spec(module) is not None -def is_dask_collection(x: typing.Any) -> bool: +def is_dask_collection(x: object) -> TypeGuard[DaskCollection]: if module_available("dask"): - from dask.base import is_dask_collection + from dask.typing import DaskCollection - return is_dask_collection(x) + return isinstance(x, DaskCollection) return False -def is_duck_array(value: typing.Any) -> TypeGuard[T_DuckArray]: +def is_duck_array(value: object) -> TypeGuard[T_DuckArray]: if isinstance(value, np.ndarray): return True return ( @@ -57,11 +117,19 @@ def is_duck_array(value: typing.Any) -> TypeGuard[T_DuckArray]: ) -def is_duck_dask_array(x: typing.Any) -> bool: - return is_duck_array(x) and is_dask_collection(x) +def is_duck_dask_array(x: T_DuckArray) -> TypeGuard[DaskArray]: + return is_dask_collection(x) + + +def is_chunked_duck_array( + x: T_DuckArray, +) -> TypeGuard[_ChunkedArray[np.dtype[np.generic]]]: + return hasattr(x, "chunks") -def to_0d_object_array(value: typing.Any) -> np.ndarray: +def to_0d_object_array( + value: object, +) -> np.ndarray[Any, np.dtype[np.object_]]: """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object.""" result = np.empty((), dtype=object) result[()] = value diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index 9d37a6c794c..ea1588bf554 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + import numpy as np import pytest @@ -5,9 +9,12 @@ from xarray.namedarray.core import NamedArray, as_compatible_data from xarray.namedarray.utils import T_DuckArray +if TYPE_CHECKING: + from xarray.namedarray.utils import Self # type: ignore[attr-defined] + @pytest.fixture -def random_inputs() -> np.ndarray: +def random_inputs() -> np.ndarray[Any, np.dtype[np.float32]]: return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) @@ -28,7 +35,7 @@ def test_as_compatible_data( def test_as_compatible_data_with_masked_array() -> None: - masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) + masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) # type: ignore[no-untyped-call] with pytest.raises(NotImplementedError): as_compatible_data(masked_array) @@ -39,27 +46,53 @@ def test_as_compatible_data_with_0d_object() -> None: np.array_equal(as_compatible_data(data), data) -def test_as_compatible_data_with_explicitly_indexed(random_inputs) -> None: +def test_as_compatible_data_with_explicitly_indexed( + random_inputs: np.ndarray[Any, Any] +) -> None: # TODO: Make xr.core.indexing.ExplicitlyIndexed pass is_duck_array and remove this test. 
- class CustomArray(xr.core.indexing.NDArrayMixin): - def __init__(self, array): + class CustomArrayBase(xr.core.indexing.NDArrayMixin): + def __init__(self, array: T_DuckArray) -> None: self.array = array - class CustomArrayIndexable(CustomArray, xr.core.indexing.ExplicitlyIndexed): + @property + def dtype(self) -> np.dtype[np.generic]: + return self.array.dtype + + @property + def shape(self) -> tuple[int, ...]: + return self.array.shape + + @property + def real(self) -> Self: + raise NotImplementedError + + @property + def imag(self) -> Self: + raise NotImplementedError + + def astype(self, dtype: np.typing.DTypeLike) -> Self: + raise NotImplementedError + + class CustomArray(CustomArrayBase): + def __array__(self) -> np.ndarray[Any, np.dtype[np.generic]]: + return np.array(self.array) + + class CustomArrayIndexable(CustomArrayBase, xr.core.indexing.ExplicitlyIndexed): pass array = CustomArray(random_inputs) - output = as_compatible_data(array) + output: CustomArray = as_compatible_data(array) assert isinstance(output, np.ndarray) - array = CustomArrayIndexable(random_inputs) - output = as_compatible_data(array) - assert isinstance(output, CustomArrayIndexable) + array2 = CustomArrayIndexable(random_inputs) + output2: CustomArrayIndexable = as_compatible_data(array2) + assert isinstance(output2, CustomArrayIndexable) def test_properties() -> None: data = 0.5 * np.arange(10).reshape(2, 5) - named_array: NamedArray[np.ndarray] = NamedArray(["x", "y"], data, {"key": "value"}) + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray(["x", "y"], data, {"key": "value"}) assert named_array.dims == ("x", "y") assert np.array_equal(named_array.data, data) assert named_array.attrs == {"key": "value"} @@ -71,9 +104,8 @@ def test_properties() -> None: def test_attrs() -> None: - named_array: NamedArray[np.ndarray] = NamedArray( - ["x", "y"], np.arange(10).reshape(2, 5) - ) + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5)) assert named_array.attrs == {} named_array.attrs["key"] = "value" assert named_array.attrs == {"key": "value"} @@ -81,8 +113,9 @@ def test_attrs() -> None: assert named_array.attrs == {"key": "value2"} -def test_data(random_inputs) -> None: - named_array: NamedArray[np.ndarray] = NamedArray(["x", "y", "z"], random_inputs) +def test_data(random_inputs: np.ndarray[Any, Any]) -> None: + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray(["x", "y", "z"], random_inputs) assert np.array_equal(named_array.data, random_inputs) with pytest.raises(ValueError): named_array.data = np.random.random((3, 4)).astype(np.float64) @@ -96,8 +129,9 @@ def test_data(random_inputs) -> None: (np.bytes_("foo"), np.dtype("S3")), ], ) -def test_0d_string(data, dtype: np.typing.DTypeLike) -> None: - named_array: NamedArray[np.ndarray] = NamedArray([], data) +def test_0d_string(data: Any, dtype: np.typing.DTypeLike) -> None: + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray([], data) assert named_array.data == data assert named_array.dims == () assert named_array.sizes == {} @@ -108,7 +142,8 @@ def test_0d_string(data, dtype: np.typing.DTypeLike) -> None: def test_0d_object() -> None: - named_array: NamedArray[np.ndarray] = NamedArray([], (10, 12, 12)) + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray([], (10, 12, 12)) expected_data = np.empty((), dtype=object) expected_data[()] = (10, 12, 12) assert np.array_equal(named_array.data, expected_data) @@ 
-122,7 +157,8 @@ def test_0d_object() -> None: def test_0d_datetime() -> None: - named_array: NamedArray[np.ndarray] = NamedArray([], np.datetime64("2000-01-01")) + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray([], np.datetime64("2000-01-01")) assert named_array.dtype == np.dtype("datetime64[D]") @@ -140,8 +176,11 @@ def test_0d_datetime() -> None: (np.timedelta64(1, "as"), np.dtype("timedelta64[as]")), ], ) -def test_0d_timedelta(timedelta, expected_dtype: np.dtype) -> None: - named_array: NamedArray[np.ndarray] = NamedArray([], timedelta) +def test_0d_timedelta( + timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64] +) -> None: + named_array: NamedArray[np.ndarray[Any, np.dtype[np.timedelta64]]] + named_array = NamedArray([], timedelta) assert named_array.dtype == expected_dtype assert named_array.data == timedelta @@ -156,8 +195,9 @@ def test_0d_timedelta(timedelta, expected_dtype: np.dtype) -> None: ([], [], ("x",), True), ], ) -def test_dims_setter(dims, data_shape, new_dims, raises: bool) -> None: - named_array: NamedArray[np.ndarray] = NamedArray(dims, np.random.random(data_shape)) +def test_dims_setter(dims: Any, data_shape: Any, new_dims: Any, raises: bool) -> None: + named_array: NamedArray[np.ndarray[Any, Any]] + named_array = NamedArray(dims, np.random.random(data_shape)) assert named_array.dims == tuple(dims) if raises: with pytest.raises(ValueError): From 36fe91786190aa08a6f4ff78d560a91936e37cc2 Mon Sep 17 00:00:00 2001 From: Bart Schilperoort Date: Wed, 4 Oct 2023 17:16:37 +0200 Subject: [PATCH 30/46] Add xarray-regrid to ecosystem.rst (#8270) * Add xarray-regrid to ecosystem.rst * Add xarray-regrid addition to whats-new. --- doc/ecosystem.rst | 1 + doc/whats-new.rst | 3 +++ 2 files changed, 4 insertions(+) diff --git a/doc/ecosystem.rst b/doc/ecosystem.rst index e6e970c6239..fc5ae963a1d 100644 --- a/doc/ecosystem.rst +++ b/doc/ecosystem.rst @@ -41,6 +41,7 @@ Geosciences harmonic wind analysis in Python. - `wradlib `_: An Open Source Library for Weather Radar Data Processing. - `wrf-python `_: A collection of diagnostic and interpolation routines for use with output of the Weather Research and Forecasting (WRF-ARW) Model. +- `xarray-regrid `_: xarray extension for regridding rectilinear data. - `xarray-simlab `_: xarray extension for computer model simulations. - `xarray-spatial `_: Numba-accelerated raster-based spatial processing tools (NDVI, curvature, zonal-statistics, proximity, hillshading, viewshed, etc.) - `xarray-topo `_: xarray extension for topographic analysis and modelling. diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 63c9dee04c5..e73a1a7fa62 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -42,6 +42,9 @@ Bug fixes Documentation ~~~~~~~~~~~~~ +- Added xarray-regrid to the list of xarray related projects (:pull:`8272`). + By `Bart Schilperoort `_. + Internal Changes ~~~~~~~~~~~~~~~~ From 8d54acf463b0fb29fffacd68f460e12477e4900c Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> Date: Wed, 4 Oct 2023 11:14:24 -0700 Subject: [PATCH 31/46] copy the `dtypes` module to the `namedarray` package. 
(#8250)

* move dtypes module to namedarray

* keep original dtypes

* revert utils changes

* Update xarray/namedarray/dtypes.py

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix missing imports

* update typing

* fix return types

* type fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* type fixes

---------

Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian
---
 xarray/namedarray/core.py   |   2 +-
 xarray/namedarray/dtypes.py | 199 ++++++++++++++++++++++++++++++++++++
 xarray/namedarray/utils.py  |  27 +++++
 3 files changed, 227 insertions(+), 1 deletion(-)
 create mode 100644 xarray/namedarray/dtypes.py

diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py
index 9b7aff9d067..ec3d8fa171b 100644
--- a/xarray/namedarray/core.py
+++ b/xarray/namedarray/core.py
@@ -267,7 +267,7 @@ def imag(self) -> Self:
         """
         return self._replace(data=self.data.imag)
 
-    def __dask_tokenize__(self) -> Hashable | None:
+    def __dask_tokenize__(self) -> Hashable:
         # Use v.data, instead of v._data, in order to cope with the wrappers
         # around NetCDF and the like
         from dask.base import normalize_token
diff --git a/xarray/namedarray/dtypes.py b/xarray/namedarray/dtypes.py
new file mode 100644
index 00000000000..7a83bd17064
--- /dev/null
+++ b/xarray/namedarray/dtypes.py
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+import functools
+import sys
+from typing import Any, Literal
+
+if sys.version_info >= (3, 10):
+    from typing import TypeGuard
+else:
+    from typing_extensions import TypeGuard
+
+import numpy as np
+
+from xarray.namedarray import utils
+
+# Use as a sentinel value to indicate a dtype appropriate NA value.
+NA = utils.ReprObject("<NA>")
+
+
+@functools.total_ordering
+class AlwaysGreaterThan:
+    def __gt__(self, other: Any) -> Literal[True]:
+        return True
+
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, type(self))
+
+
+@functools.total_ordering
+class AlwaysLessThan:
+    def __lt__(self, other: Any) -> Literal[True]:
+        return True
+
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, type(self))
+
+
+# Equivalence to np.inf (-np.inf) for object-type
+INF = AlwaysGreaterThan()
+NINF = AlwaysLessThan()
+
+
+# Pairs of types that, if both found, should be promoted to object dtype
+# instead of following NumPy's own type-promotion rules. These type promotion
+# rules match pandas instead. For reference, see the NumPy type hierarchy:
+# https://numpy.org/doc/stable/reference/arrays.scalars.html
+PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
+    (np.number, np.character),  # numpy promotes to character
+    (np.bool_, np.character),  # numpy promotes to character
+    (np.bytes_, np.str_),  # numpy promotes to unicode
+)
+
+
+def maybe_promote(dtype: np.dtype[np.generic]) -> tuple[np.dtype[np.generic], Any]:
+    """Simpler equivalent of pandas.core.common._maybe_promote
+
+    Parameters
+    ----------
+    dtype : np.dtype
+
+    Returns
+    -------
+    dtype : Promoted dtype that can hold missing values.
+    fill_value : Valid missing value for the promoted dtype.
+    """
+    # N.B. these casting rules should match pandas
+    dtype_: np.typing.DTypeLike
+    fill_value: Any
+    if np.issubdtype(dtype, np.floating):
+        dtype_ = dtype
+        fill_value = np.nan
+    elif np.issubdtype(dtype, np.timedelta64):
+        # See https://github.com/numpy/numpy/issues/10685
+        # np.timedelta64 is a subclass of np.integer
+        # Check np.timedelta64 before np.integer
+        fill_value = np.timedelta64("NaT")
+        dtype_ = dtype
+    elif np.issubdtype(dtype, np.integer):
+        dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
+        fill_value = np.nan
+    elif np.issubdtype(dtype, np.complexfloating):
+        dtype_ = dtype
+        fill_value = np.nan + np.nan * 1j
+    elif np.issubdtype(dtype, np.datetime64):
+        dtype_ = dtype
+        fill_value = np.datetime64("NaT")
+    else:
+        dtype_ = object
+        fill_value = np.nan
+
+    dtype_out = np.dtype(dtype_)
+    fill_value = dtype_out.type(fill_value)
+    return dtype_out, fill_value
+
+
+NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype}
+
+
+def get_fill_value(dtype: np.dtype[np.generic]) -> Any:
+    """Return an appropriate fill value for this dtype.
+
+    Parameters
+    ----------
+    dtype : np.dtype
+
+    Returns
+    -------
+    fill_value : Missing value corresponding to this dtype.
+    """
+    _, fill_value = maybe_promote(dtype)
+    return fill_value
+
+
+def get_pos_infinity(
+    dtype: np.dtype[np.generic], max_for_int: bool = False
+) -> float | complex | AlwaysGreaterThan:
+    """Return an appropriate positive infinity for this dtype.
+
+    Parameters
+    ----------
+    dtype : np.dtype
+    max_for_int : bool
+        Return np.iinfo(dtype).max instead of np.inf
+
+    Returns
+    -------
+    fill_value : positive infinity value corresponding to this dtype.
+    """
+    if issubclass(dtype.type, np.floating):
+        return np.inf
+
+    if issubclass(dtype.type, np.integer):
+        return np.iinfo(dtype.type).max if max_for_int else np.inf
+    if issubclass(dtype.type, np.complexfloating):
+        return np.inf + 1j * np.inf
+
+    return INF
+
+
+def get_neg_infinity(
+    dtype: np.dtype[np.generic], min_for_int: bool = False
+) -> float | complex | AlwaysLessThan:
+    """Return an appropriate negative infinity for this dtype.
+
+    Parameters
+    ----------
+    dtype : np.dtype
+    min_for_int : bool
+        Return np.iinfo(dtype).min instead of -np.inf
+
+    Returns
+    -------
+    fill_value : negative infinity value corresponding to this dtype.
+    """
+    if issubclass(dtype.type, np.floating):
+        return -np.inf
+
+    if issubclass(dtype.type, np.integer):
+        return np.iinfo(dtype.type).min if min_for_int else -np.inf
+    if issubclass(dtype.type, np.complexfloating):
+        return -np.inf - 1j * np.inf
+
+    return NINF
+
+
+def is_datetime_like(
+    dtype: np.dtype[np.generic],
+) -> TypeGuard[np.datetime64 | np.timedelta64]:
+    """Check if a dtype is a subclass of the numpy datetime types"""
+    return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
+
+
+def result_type(
+    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
+) -> np.dtype[np.generic]:
+    """Like np.result_type, but with type promotion rules matching pandas.
+
+    Examples of changed behavior:
+    number + string -> object (not string)
+    bytes + unicode -> object (not unicode)
+
+    Parameters
+    ----------
+    *arrays_and_dtypes : list of arrays and dtypes
+        The dtype is extracted from both numpy and dask arrays.
+
+    Returns
+    -------
+    numpy.dtype for the result.
+ """ + types = {np.result_type(t).type for t in arrays_and_dtypes} + + for left, right in PROMOTE_TO_OBJECT: + if any(issubclass(t, left) for t in types) and any( + issubclass(t, right) for t in types + ): + return np.dtype(object) + + return np.result_type(*arrays_and_dtypes) diff --git a/xarray/namedarray/utils.py b/xarray/namedarray/utils.py index c77009aeb2d..6f7658ea00b 100644 --- a/xarray/namedarray/utils.py +++ b/xarray/namedarray/utils.py @@ -2,6 +2,7 @@ import importlib import sys +from collections.abc import Hashable from enum import Enum from typing import TYPE_CHECKING, Any, Final, Protocol, TypeVar @@ -134,3 +135,29 @@ def to_0d_object_array( result = np.empty((), dtype=object) result[()] = value return result + + +class ReprObject: + """Object that prints as the given value, for use with sentinel values.""" + + __slots__ = ("_value",) + + _value: str + + def __init__(self, value: str): + self._value = value + + def __repr__(self) -> str: + return self._value + + def __eq__(self, other: ReprObject | Any) -> bool: + # TODO: What type can other be? ArrayLike? + return self._value == other._value if isinstance(other, ReprObject) else False + + def __hash__(self) -> int: + return hash((type(self), self._value)) + + def __dask_tokenize__(self) -> Hashable: + from dask.base import normalize_token + + return normalize_token((type(self), self._value)) # type: ignore[no-any-return] From 25c76892b2688e54483209e3e34012abcb5d6b1a Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 4 Oct 2023 12:05:02 -0700 Subject: [PATCH 32/46] Mandate kwargs on `to_zarr` (#8257) * Mandate kwargs on `to_zarr` This aleviates some of the dangers of having these in a different order between `da` & `ds`. _Technically_ it's a breaking change, but only very technically, given that I would wager literally no one has a dozen positional arguments to this method. So I think it's OK. --- doc/whats-new.rst | 4 ++++ xarray/backends/api.py | 2 ++ xarray/core/dataarray.py | 2 ++ xarray/core/dataset.py | 2 ++ 4 files changed, 10 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e73a1a7fa62..fb1c07f0616 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -30,6 +30,10 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ +- :py:meth:`Dataset.to_zarr` & :py:meth:`DataArray.to_zarr` require keyword + arguments after the initial 7 positional arguments. + By `Maximilian Roos `_. 
+ Deprecations ~~~~~~~~~~~~ diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 7ca4377e4cf..27e155872de 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1528,6 +1528,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: Literal[True] = True, consolidated: bool | None = None, append_dim: Hashable | None = None, @@ -1573,6 +1574,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ef4389f3c6c..904688d7df9 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4010,6 +4010,7 @@ def to_zarr( mode: Literal["w", "w-", "a", "r+", None] = None, synchronizer=None, group: str | None = None, + *, encoding: Mapping | None = None, compute: Literal[True] = True, consolidated: bool | None = None, @@ -4050,6 +4051,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 459e2f3fce7..ef27071eace 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2293,6 +2293,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: Literal[True] = True, consolidated: bool | None = None, append_dim: Hashable | None = None, @@ -2336,6 +2337,7 @@ def to_zarr( synchronizer=None, group: str | None = None, encoding: Mapping | None = None, + *, compute: bool = True, consolidated: bool | None = None, append_dim: Hashable | None = None, From e09609c234590dffb1b46c8526c3524da561c0ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20H=C3=B8xbro=20Hansen?= Date: Wed, 4 Oct 2023 22:24:49 +0200 Subject: [PATCH 33/46] Don't raise rename warning if it is a no operation (#8266) * Don't raise rename warning if it is a no operation * xr.Dataset -> Dataset * Remove pytest.warns * Add whatsnew --- doc/whats-new.rst | 3 +++ xarray/core/dataset.py | 3 +++ xarray/tests/test_dataarray.py | 10 ++++++++++ xarray/tests/test_dataset.py | 12 ++++++++++-- 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index fb1c07f0616..c15bbd4bd7f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,6 +41,9 @@ Deprecations Bug fixes ~~~~~~~~~ +- :py:meth:`DataArray.rename` & :py:meth:`Dataset.rename` would emit a warning + when the operation was a no-op. (:issue:`8266`) + By `Simon Hansen `_. 
Documentation diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ef27071eace..bf0daf3c6d4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4167,6 +4167,9 @@ def _rename( create_dim_coord = False new_k = name_dict[k] + if k == new_k: + continue # Same name, nothing to do + if k in self.dims and new_k in self._coord_names: coord_dims = self._variables[name_dict[k]].dims if coord_dims == (k,): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 63175f2be40..d497cd5a54d 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -1883,6 +1883,16 @@ def test_rename_dimension_coord_warnings(self) -> None: ): da.rename(x="y") + # No operation should not raise a warning + da = xr.DataArray( + data=np.ones((2, 3)), + dims=["x", "y"], + coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + da.rename(x="x") + def test_init_value(self) -> None: expected = DataArray( np.full((3, 4), 3), dims=["x", "y"], coords=[range(3), range(4)] diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ac641c4abc3..08bfeccaac7 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -3032,8 +3032,7 @@ def test_rename_old_name(self) -> None: def test_rename_same_name(self) -> None: data = create_test_data() newnames = {"var1": "var1", "dim2": "dim2"} - with pytest.warns(UserWarning, match="does not create an index anymore"): - renamed = data.rename(newnames) + renamed = data.rename(newnames) assert_identical(renamed, data) def test_rename_dims(self) -> None: @@ -3103,6 +3102,15 @@ def test_rename_dimension_coord_warnings(self) -> None: ): ds.rename(x="y") + # No operation should not raise a warning + ds = Dataset( + data_vars={"data": (("x", "y"), np.ones((2, 3)))}, + coords={"x": range(2), "y": range(3), "a": ("x", [3, 4])}, + ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + ds.rename(x="x") + def test_rename_multiindex(self) -> None: midx = pd.MultiIndex.from_tuples([([1, 2]), ([3, 4])], names=["a", "b"]) midx_coords = Coordinates.from_pandas_multiindex(midx, "x") From bd40c20a5fd025724af5862765eab6bf90eb92f5 Mon Sep 17 00:00:00 2001 From: Tom White Date: Thu, 5 Oct 2023 19:41:19 +0100 Subject: [PATCH 34/46] Use duck array ops in more places (#8267) * Use duck array ops for `reshape` * Use duck array ops for `sum` * Use duck array ops for `astype` * Use duck array ops for `ravel` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update what's new --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/whats-new.rst | 4 ++++ xarray/core/accessor_dt.py | 13 +++++++------ xarray/core/computation.py | 3 ++- xarray/core/duck_array_ops.py | 6 +++++- xarray/core/nanops.py | 10 +++++++--- xarray/core/variable.py | 2 +- xarray/tests/test_coarsen.py | 12 ++++++++---- xarray/tests/test_variable.py | 2 +- 8 files changed, 35 insertions(+), 17 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c15bbd4bd7f..ed6b5043ab9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -56,6 +56,10 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- More improvements to support the Python `array API standard `_ + by using duck array ops in more places in the codebase. (:pull:`8267`) + By `Tom White `_. + .. 
_whats-new.2023.09.0: diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py index 4c1ce4b5c48..8255e2a5232 100644 --- a/xarray/core/accessor_dt.py +++ b/xarray/core/accessor_dt.py @@ -7,6 +7,7 @@ import pandas as pd from xarray.coding.times import infer_calendar_name +from xarray.core import duck_array_ops from xarray.core.common import ( _contains_datetime_like_objects, is_np_datetime_like, @@ -50,7 +51,7 @@ def _access_through_cftimeindex(values, name): from xarray.coding.cftimeindex import CFTimeIndex if not isinstance(values, CFTimeIndex): - values_as_cftimeindex = CFTimeIndex(values.ravel()) + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) else: values_as_cftimeindex = values if name == "season": @@ -69,7 +70,7 @@ def _access_through_series(values, name): """Coerce an array of datetime-like values to a pandas Series and access requested datetime component """ - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) if name == "season": months = values_as_series.dt.month.values field_values = _season_from_months(months) @@ -148,10 +149,10 @@ def _round_through_series_or_index(values, name, freq): from xarray.coding.cftimeindex import CFTimeIndex if is_np_datetime_like(values.dtype): - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) method = getattr(values_as_series.dt, name) else: - values_as_cftimeindex = CFTimeIndex(values.ravel()) + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) method = getattr(values_as_cftimeindex, name) field_values = method(freq=freq).values @@ -195,7 +196,7 @@ def _strftime_through_cftimeindex(values, date_format: str): """ from xarray.coding.cftimeindex import CFTimeIndex - values_as_cftimeindex = CFTimeIndex(values.ravel()) + values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values)) field_values = values_as_cftimeindex.strftime(date_format) return field_values.values.reshape(values.shape) @@ -205,7 +206,7 @@ def _strftime_through_series(values, date_format: str): """Coerce an array of datetime-like values to a pandas Series and apply string formatting """ - values_as_series = pd.Series(values.ravel(), copy=False) + values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False) strs = values_as_series.dt.strftime(date_format) return strs.values.reshape(values.shape) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index c707403db97..db786910f22 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -2123,7 +2123,8 @@ def _calc_idxminmax( chunkmanager = get_chunked_array_type(array.data) chunks = dict(zip(array.dims, array.chunks)) dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) - res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape)) + data = dask_coord[duck_array_ops.ravel(indx.data)] + res = indx.copy(data=duck_array_ops.reshape(data, indx.shape)) # we need to attach back the dim name res.name = dim else: diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 4f245e59f73..078aab0ed63 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -337,6 +337,10 @@ def reshape(array, shape): return xp.reshape(array, shape) +def ravel(array): + return reshape(array, (-1,)) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: @@ -363,7 +367,7 @@ def f(values, axis=None, skipna=None, 
**kwargs): values = asarray(values) if coerce_strings and values.dtype.kind in "SU": - values = values.astype(object) + values = astype(values, object) func = None if skipna or (skipna is None and values.dtype.kind in "cfO"): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 3b8ddfe032d..fc7240139aa 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -4,7 +4,7 @@ import numpy as np -from xarray.core import dtypes, nputils, utils +from xarray.core import dtypes, duck_array_ops, nputils, utils from xarray.core.duck_array_ops import ( astype, count, @@ -21,12 +21,16 @@ def _maybe_null_out(result, axis, mask, min_count=1): xarray version of pandas.core.nanops._maybe_null_out """ if axis is not None and getattr(result, "ndim", False): - null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0 + null_mask = ( + np.take(mask.shape, axis).prod() + - duck_array_ops.sum(mask, axis) + - min_count + ) < 0 dtype, fill_value = dtypes.maybe_promote(result.dtype) result = where(null_mask, fill_value, astype(result, dtype)) elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES: - null_mask = mask.size - mask.sum() + null_mask = mask.size - duck_array_ops.sum(mask) result = where(null_mask < min_count, np.nan, result) return result diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 4eeda073555..3baecfe5f6d 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2324,7 +2324,7 @@ def coarsen_reshape(self, windows, boundary, side): else: shape.append(variable.shape[i]) - return variable.data.reshape(shape), tuple(axes) + return duck_array_ops.reshape(variable.data, shape), tuple(axes) def isnull(self, keep_attrs: bool | None = None): """Test each value in the array for whether it is a missing value. 
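For a concrete sense of what routing through duck array ops buys, here is a minimal illustrative sketch (not part of the patch; plain numpy stands in for any array-API-compliant duck array):

    import numpy as np
    from xarray.core import duck_array_ops

    x = np.arange(6).reshape(2, 3)

    # Dispatching through duck_array_ops uses the array's own namespace
    # rather than assuming numpy methods exist on the underlying object.
    duck_array_ops.ravel(x)          # array([0, 1, 2, 3, 4, 5])
    duck_array_ops.astype(x, float)  # float64 copy of x
    duck_array_ops.sum(x)            # 15

With a non-numpy array-API array in place of `x`, the same calls go through that library's implementations instead.
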
diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py index e345ae691ec..01d5393e289 100644 --- a/xarray/tests/test_coarsen.py +++ b/xarray/tests/test_coarsen.py @@ -6,6 +6,7 @@ import xarray as xr from xarray import DataArray, Dataset, set_options +from xarray.core import duck_array_ops from xarray.tests import ( assert_allclose, assert_equal, @@ -272,21 +273,24 @@ def test_coarsen_construct(self, dask: bool) -> None: expected = xr.Dataset(attrs={"foo": "bar"}) expected["vart"] = ( ("year", "month"), - ds.vart.data.reshape((-1, 12)), + duck_array_ops.reshape(ds.vart.data, (-1, 12)), {"a": "b"}, ) expected["varx"] = ( ("x", "x_reshaped"), - ds.varx.data.reshape((-1, 5)), + duck_array_ops.reshape(ds.varx.data, (-1, 5)), {"a": "b"}, ) expected["vartx"] = ( ("x", "x_reshaped", "year", "month"), - ds.vartx.data.reshape(2, 5, 4, 12), + duck_array_ops.reshape(ds.vartx.data, (2, 5, 4, 12)), {"a": "b"}, ) expected["vary"] = ds.vary - expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12))) + expected.coords["time"] = ( + ("year", "month"), + duck_array_ops.reshape(ds.time.data, (-1, 12)), + ) with raise_if_dask_computes(): actual = ds.coarsen(time=12, x=5).construct( diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index f162b1c7d0a..1ffd51f4a04 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -916,7 +916,7 @@ def test_pad_constant_values(self, xr_arg, np_arg): actual = v.pad(**xr_arg) expected = np.pad( - np.array(v.data.astype(float)), + np.array(duck_array_ops.astype(v.data, float)), np_arg, mode="constant", constant_values=np.nan, From 938579dbf7360c7d760ee7c6d3ffb2753bfa92e4 Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Thu, 5 Oct 2023 22:38:48 +0200 Subject: [PATCH 35/46] make more args kw only (except 'dim') (#6403) * make more args kw only (except 'dim') * add deprecation * add forgotten deprecations * doctest fixes * fix some warnings * remove expand_dims again * undo expand_dims, fix mypy * whats-new entry [skip-ci] * add typing to _deprecate_positional_args helper * Update xarray/util/deprecation_helpers.py Co-authored-by: Michael Niklas * fix kw only for overload * move typing * restore # type: ignore * add type ignores to test_deprecation_helpers --------- Co-authored-by: Michael Niklas Co-authored-by: Deepak Cherian --- doc/whats-new.rst | 2 ++ xarray/core/dataarray.py | 34 +++++++++++++++++++++++- xarray/core/dataset.py | 19 +++++++++++++ xarray/core/groupby.py | 3 +++ xarray/core/weighted.py | 13 +++++++++ xarray/tests/test_dataset.py | 4 +-- xarray/tests/test_deprecation_helpers.py | 30 ++++++++++----------- xarray/util/deprecation_helpers.py | 5 +++- 8 files changed, 91 insertions(+), 19 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ed6b5043ab9..b3bd372caf7 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -30,6 +30,8 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ +- Made more arguments keyword-only (e.g. ``keep_attrs``, ``skipna``) for many :py:class:`xarray.DataArray` and + :py:class:`xarray.Dataset` methods (:pull:`6403`). By `Mathias Hauser `_. - :py:meth:`Dataset.to_zarr` & :py:meth:`DataArray.to_zarr` require keyword arguments after the initial 7 positional arguments. By `Maximilian Roos `_. 
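As a rough usage sketch of the keyword-only pattern applied below (the decorator is the one in xarray/util/deprecation_helpers.py, updated at the end of this patch; the function here is hypothetical):

    from xarray.util.deprecation_helpers import _deprecate_positional_args

    @_deprecate_positional_args("v2023.10.0")
    def quantile(q, dim=None, *, method="linear", skipna=None):
        return q, dim, method, skipna

    quantile(0.5, "time", method="linear")  # fine
    quantile(0.5, "time", "linear")  # FutureWarning: Passing 'method' as positional

During the deprecation cycle the extra positional argument is remapped to the keyword-only parameter with a FutureWarning; once the compatibility shim is removed, the same call raises a TypeError.
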
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 904688d7df9..04c9fb17257 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -66,6 +66,7 @@ ) from xarray.plot.accessor import DataArrayPlotAccessor from xarray.plot.utils import _get_units_from_attrs +from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: from typing import TypeVar, Union @@ -954,6 +955,7 @@ def coords(self) -> DataArrayCoordinates: def reset_coords( self, names: Dims = None, + *, drop: Literal[False] = False, ) -> Dataset: ... @@ -967,9 +969,11 @@ def reset_coords( ) -> Self: ... + @_deprecate_positional_args("v2023.10.0") def reset_coords( self, names: Dims = None, + *, drop: bool = False, ) -> Self | Dataset: """Given names of coordinates, reset them to become variables. @@ -1287,9 +1291,11 @@ def chunksizes(self) -> Mapping[Any, tuple[int, ...]]: all_variables = [self.variable] + [c.variable for c in self.coords.values()] return get_chunksizes(all_variables) + @_deprecate_positional_args("v2023.10.0") def chunk( self, chunks: T_Chunks = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) + *, name_prefix: str = "xarray-", token: str | None = None, lock: bool = False, @@ -1724,9 +1730,11 @@ def thin( ds = self._to_temp_dataset().thin(indexers, **indexers_kwargs) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def broadcast_like( self, other: T_DataArrayOrSet, + *, exclude: Iterable[Hashable] | None = None, ) -> Self: """Broadcast this DataArray against another Dataset or DataArray. @@ -1835,9 +1843,11 @@ def _reindex_callback( return da + @_deprecate_positional_args("v2023.10.0") def reindex_like( self, other: T_DataArrayOrSet, + *, method: ReindexMethodOptions = None, tolerance: int | float | Iterable[int | float] | None = None, copy: bool = True, @@ -2005,9 +2015,11 @@ def reindex_like( fill_value=fill_value, ) + @_deprecate_positional_args("v2023.10.0") def reindex( self, indexers: Mapping[Any, Any] | None = None, + *, method: ReindexMethodOptions = None, tolerance: float | Iterable[float] | None = None, copy: bool = True, @@ -2787,9 +2799,11 @@ def stack( ) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def unstack( self, dim: Dims = None, + *, fill_value: Any = dtypes.NA, sparse: bool = False, ) -> Self: @@ -2847,7 +2861,7 @@ def unstack( -------- DataArray.stack """ - ds = self._to_temp_dataset().unstack(dim, fill_value, sparse) + ds = self._to_temp_dataset().unstack(dim, fill_value=fill_value, sparse=sparse) return self._from_temp_dataset(ds) def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Dataset: @@ -3198,9 +3212,11 @@ def drop_isel( dataset = dataset.drop_isel(indexers=indexers, **indexers_kwargs) return self._from_temp_dataset(dataset) + @_deprecate_positional_args("v2023.10.0") def dropna( self, dim: Hashable, + *, how: Literal["any", "all"] = "any", thresh: int | None = None, ) -> Self: @@ -4696,10 +4712,12 @@ def _title_for_slice(self, truncate: int = 50) -> str: return title + @_deprecate_positional_args("v2023.10.0") def diff( self, dim: Hashable, n: int = 1, + *, label: Literal["upper", "lower"] = "upper", ) -> Self: """Calculate the n-th order discrete difference along given axis. 
@@ -4985,10 +5003,12 @@ def sortby( ds = self._to_temp_dataset().sortby(variables, ascending=ascending) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def quantile( self, q: ArrayLike, dim: Dims = None, + *, method: QuantileMethods = "linear", keep_attrs: bool | None = None, skipna: bool | None = None, @@ -5103,9 +5123,11 @@ def quantile( ) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def rank( self, dim: Hashable, + *, pct: bool = False, keep_attrs: bool | None = None, ) -> Self: @@ -5678,9 +5700,11 @@ def pad( ) return self._from_temp_dataset(ds) + @_deprecate_positional_args("v2023.10.0") def idxmin( self, dim: Hashable | None = None, + *, skipna: bool | None = None, fill_value: Any = dtypes.NA, keep_attrs: bool | None = None, @@ -5774,9 +5798,11 @@ def idxmin( keep_attrs=keep_attrs, ) + @_deprecate_positional_args("v2023.10.0") def idxmax( self, dim: Hashable = None, + *, skipna: bool | None = None, fill_value: Any = dtypes.NA, keep_attrs: bool | None = None, @@ -5870,9 +5896,11 @@ def idxmax( keep_attrs=keep_attrs, ) + @_deprecate_positional_args("v2023.10.0") def argmin( self, dim: Dims = None, + *, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, @@ -5970,9 +5998,11 @@ def argmin( else: return self._replace_maybe_drop_dims(result) + @_deprecate_positional_args("v2023.10.0") def argmax( self, dim: Dims = None, + *, axis: int | None = None, keep_attrs: bool | None = None, skipna: bool | None = None, @@ -6317,9 +6347,11 @@ def curvefit( kwargs=kwargs, ) + @_deprecate_positional_args("v2023.10.0") def drop_duplicates( self, dim: Hashable | Iterable[Hashable], + *, keep: Literal["first", "last", False] = "first", ) -> Self: """Returns a new DataArray with duplicate dimension values removed. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bf0daf3c6d4..5f709a5cd63 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -123,6 +123,7 @@ calculate_dimensions, ) from xarray.plot.accessor import DatasetPlotAccessor +from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: from numpy.typing import ArrayLike @@ -4775,9 +4776,11 @@ def set_index( variables, coord_names=coord_names, indexes=indexes_ ) + @_deprecate_positional_args("v2023.10.0") def reset_index( self, dims_or_levels: Hashable | Sequence[Hashable], + *, drop: bool = False, ) -> Self: """Reset the specified index(es) or multi-index level(s). @@ -5412,9 +5415,11 @@ def _unstack_full_reindex( variables, coord_names=coord_names, indexes=indexes ) + @_deprecate_positional_args("v2023.10.0") def unstack( self, dim: Dims = None, + *, fill_value: Any = xrdtypes.NA, sparse: bool = False, ) -> Self: @@ -6155,9 +6160,11 @@ def transpose( ds._variables[name] = var.transpose(*var_dims) return ds + @_deprecate_positional_args("v2023.10.0") def dropna( self, dim: Hashable, + *, how: Literal["any", "all"] = "any", thresh: int | None = None, subset: Iterable[Hashable] | None = None, @@ -7583,10 +7590,12 @@ def _copy_attrs_from(self, other): if v in self.variables: self.variables[v].attrs = other.variables[v].attrs + @_deprecate_positional_args("v2023.10.0") def diff( self, dim: Hashable, n: int = 1, + *, label: Literal["upper", "lower"] = "upper", ) -> Self: """Calculate the n-th order discrete difference along given axis. 
@@ -7913,10 +7922,12 @@ def sortby( indices[key] = order if ascending else order[::-1] return aligned_self.isel(indices) + @_deprecate_positional_args("v2023.10.0") def quantile( self, q: ArrayLike, dim: Dims = None, + *, method: QuantileMethods = "linear", numeric_only: bool = False, keep_attrs: bool | None = None, @@ -8091,9 +8102,11 @@ def quantile( ) return new.assign_coords(quantile=q) + @_deprecate_positional_args("v2023.10.0") def rank( self, dim: Hashable, + *, pct: bool = False, keep_attrs: bool | None = None, ) -> Self: @@ -9037,9 +9050,11 @@ def pad( attrs = self._attrs if keep_attrs else None return self._replace_with_new_dims(variables, indexes=indexes, attrs=attrs) + @_deprecate_positional_args("v2023.10.0") def idxmin( self, dim: Hashable | None = None, + *, skipna: bool | None = None, fill_value: Any = xrdtypes.NA, keep_attrs: bool | None = None, @@ -9134,9 +9149,11 @@ def idxmin( ) ) + @_deprecate_positional_args("v2023.10.0") def idxmax( self, dim: Hashable | None = None, + *, skipna: bool | None = None, fill_value: Any = xrdtypes.NA, keep_attrs: bool | None = None, @@ -9757,9 +9774,11 @@ def _wrapper(Y, *args, **kwargs): return result + @_deprecate_positional_args("v2023.10.0") def drop_duplicates( self, dim: Hashable | Iterable[Hashable], + *, keep: Literal["first", "last", False] = "first", ) -> Self: """Returns a new Dataset with duplicate dimension values removed. diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index e9ddf044568..8ed7148e2a1 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -43,6 +43,7 @@ peek_at, ) from xarray.core.variable import IndexVariable, Variable +from xarray.util.deprecation_helpers import _deprecate_positional_args if TYPE_CHECKING: from numpy.typing import ArrayLike @@ -1092,10 +1093,12 @@ def fillna(self, value: Any) -> T_Xarray: """ return ops.fillna(self, value) + @_deprecate_positional_args("v2023.10.0") def quantile( self, q: ArrayLike, dim: Dims = None, + *, method: QuantileMethods = "linear", keep_attrs: bool | None = None, skipna: bool | None = None, diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index b1ea1ee625c..28740a99020 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -11,6 +11,7 @@ from xarray.core.computation import apply_ufunc, dot from xarray.core.pycompat import is_duck_dask_array from xarray.core.types import Dims, T_Xarray +from xarray.util.deprecation_helpers import _deprecate_positional_args # Weighted quantile methods are a subset of the numpy supported quantile methods. 
QUANTILE_METHODS = Literal[ @@ -450,18 +451,22 @@ def _weighted_quantile_1d( def _implementation(self, func, dim, **kwargs): raise NotImplementedError("Use `Dataset.weighted` or `DataArray.weighted`") + @_deprecate_positional_args("v2023.10.0") def sum_of_weights( self, dim: Dims = None, + *, keep_attrs: bool | None = None, ) -> T_Xarray: return self._implementation( self._sum_of_weights, dim=dim, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def sum_of_squares( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -469,9 +474,11 @@ def sum_of_squares( self._sum_of_squares, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def sum( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -479,9 +486,11 @@ def sum( self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def mean( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -489,9 +498,11 @@ def mean( self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def var( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: @@ -499,9 +510,11 @@ def var( self._weighted_var, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) + @_deprecate_positional_args("v2023.10.0") def std( self, dim: Dims = None, + *, skipna: bool | None = None, keep_attrs: bool | None = None, ) -> T_Xarray: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 08bfeccaac7..3841398ff75 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5059,9 +5059,9 @@ def test_dropna(self) -> None: ): ds.dropna("foo") with pytest.raises(ValueError, match=r"invalid how"): - ds.dropna("a", how="somehow") # type: ignore + ds.dropna("a", how="somehow") # type: ignore[arg-type] with pytest.raises(TypeError, match=r"must specify how or thresh"): - ds.dropna("a", how=None) # type: ignore + ds.dropna("a", how=None) # type: ignore[arg-type] def test_fillna(self) -> None: ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]}) diff --git a/xarray/tests/test_deprecation_helpers.py b/xarray/tests/test_deprecation_helpers.py index 35128829073..f21c8097060 100644 --- a/xarray/tests/test_deprecation_helpers.py +++ b/xarray/tests/test_deprecation_helpers.py @@ -15,15 +15,15 @@ def f1(a, b, *, c="c", d="d"): assert result == (1, 2, 3, 4) with pytest.warns(FutureWarning, match=r".*v0.1"): - result = f1(1, 2, 3) + result = f1(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = f1(1, 2, 3) + result = f1(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = f1(1, 2, 3, 4) + result = f1(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) @_deprecate_positional_args("v0.1") @@ -31,7 +31,7 @@ def f2(a="a", *, b="b", c="c", d="d"): return a, b, c, d with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f2(1, 2) + result = f2(1, 2) # type: ignore[misc] assert result == (1, 2, "c", "d") @_deprecate_positional_args("v0.1") @@ -39,11 +39,11 @@ def f3(a, *, b="b", **kwargs): return a, b, kwargs with pytest.warns(FutureWarning, match=r"Passing 'b' as 
positional"): - result = f3(1, 2) + result = f3(1, 2) # type: ignore[misc] assert result == (1, 2, {}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f3(1, 2, f="f") + result = f3(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) @_deprecate_positional_args("v0.1") @@ -57,7 +57,7 @@ def f4(a, /, *, b="b", **kwargs): assert result == (1, 2, {"f": "f"}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = f4(1, 2, f="f") + result = f4(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) with pytest.raises(TypeError, match=r"Keyword-only param without default"): @@ -80,15 +80,15 @@ def method(self, a, b, *, c="c", d="d"): assert result == (1, 2, 3, 4) with pytest.warns(FutureWarning, match=r".*v0.1"): - result = A1().method(1, 2, 3) + result = A1().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = A1().method(1, 2, 3) + result = A1().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = A1().method(1, 2, 3, 4) + result = A1().method(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) class A2: @@ -97,11 +97,11 @@ def method(self, a=1, b=1, *, c="c", d="d"): return a, b, c, d with pytest.warns(FutureWarning, match=r"Passing 'c' as positional"): - result = A2().method(1, 2, 3) + result = A2().method(1, 2, 3) # type: ignore[misc] assert result == (1, 2, 3, "d") with pytest.warns(FutureWarning, match=r"Passing 'c, d' as positional"): - result = A2().method(1, 2, 3, 4) + result = A2().method(1, 2, 3, 4) # type: ignore[misc] assert result == (1, 2, 3, 4) class A3: @@ -110,11 +110,11 @@ def method(self, a, *, b="b", **kwargs): return a, b, kwargs with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A3().method(1, 2) + result = A3().method(1, 2) # type: ignore[misc] assert result == (1, 2, {}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A3().method(1, 2, f="f") + result = A3().method(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) class A4: @@ -129,7 +129,7 @@ def method(self, a, /, *, b="b", **kwargs): assert result == (1, 2, {"f": "f"}) with pytest.warns(FutureWarning, match=r"Passing 'b' as positional"): - result = A4().method(1, 2, f="f") + result = A4().method(1, 2, f="f") # type: ignore[misc] assert result == (1, 2, {"f": "f"}) with pytest.raises(TypeError, match=r"Keyword-only param without default"): diff --git a/xarray/util/deprecation_helpers.py b/xarray/util/deprecation_helpers.py index e9681bdf398..7b4cf901aa1 100644 --- a/xarray/util/deprecation_helpers.py +++ b/xarray/util/deprecation_helpers.py @@ -34,6 +34,9 @@ import inspect import warnings from functools import wraps +from typing import Callable, TypeVar + +T = TypeVar("T", bound=Callable) POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY @@ -41,7 +44,7 @@ EMPTY = inspect.Parameter.empty -def _deprecate_positional_args(version): +def _deprecate_positional_args(version) -> Callable[[T], T]: """Decorator for methods that issues warnings for positional arguments Using the keyword-only argument syntax in pep 3102, arguments after the From 2cd8f96a1b5ae954d7b34390e8b01fbd985fc710 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Thu, 5 Oct 
2023 20:35:21 -0700
Subject: [PATCH 36/46] Allow a function in `.sortby` method (#8273)

---
 doc/whats-new.rst        |  5 ++++-
 xarray/core/common.py    |  3 ++-
 xarray/core/dataarray.py | 31 +++++++++++++++++++++++--------
 xarray/core/dataset.py   | 26 ++++++++++++++++++------
 4 files changed, 49 insertions(+), 16 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index b3bd372caf7..55e0fbaf177 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -23,9 +23,12 @@ New Features
 ~~~~~~~~~~~~
 
 - :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for
-  the ``other`` parameter, passing the object as the first argument. Previously,
+  the ``other`` parameter, passing the object as the only argument. Previously,
   this was only valid for the ``cond`` parameter. (:issue:`8255`)
   By `Maximilian Roos `_.
+- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for
+  the ``variables`` parameter, passing the object as the only argument.
+  By `Maximilian Roos `_.
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
diff --git a/xarray/core/common.py b/xarray/core/common.py
index 2a4c4c200d4..f571576850c 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -1073,7 +1073,8 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self:
         ----------
         cond : DataArray, Dataset, or callable
             Locations at which to preserve this object's values. dtype must be `bool`.
-            If a callable, it must expect this object as its only parameter.
+            If a callable, the callable is passed this object, and the result is used as
+            the value for cond.
         other : scalar, DataArray, Dataset, or callable, optional
             Value to use for locations in this object where ``cond`` is False.
             By default, these locations are filled with NA. If a callable, it must
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index 04c9fb17257..2bcc5ab85e2 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -4941,7 +4941,10 @@ def dot(
 
     def sortby(
         self,
-        variables: Hashable | DataArray | Sequence[Hashable | DataArray],
+        variables: Hashable
+        | DataArray
+        | Sequence[Hashable | DataArray]
+        | Callable[[Self], Hashable | DataArray | Sequence[Hashable | DataArray]],
         ascending: bool = True,
     ) -> Self:
         """Sort object by labels or values (along an axis).
@@ -4962,9 +4965,10 @@ def sortby(
 
         Parameters
         ----------
-        variables : Hashable, DataArray, or sequence of Hashable or DataArray
-            1D DataArray objects or name(s) of 1D variable(s) in
-            coords whose values are used to sort this array.
+        variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable
+            1D DataArray objects or name(s) of 1D variable(s) in coords whose values are
+            used to sort this array. If a callable, the callable is passed this object,
+            and the result is used as the value(s) to sort by.
         ascending : bool, default: True
             Whether to sort by ascending or descending order.
 
@@ -4984,22 +4988,33 @@ def sortby(
         Examples
         --------
         >>> da = xr.DataArray(
-        ...     np.random.rand(5),
+        ...     np.arange(5, 0, -1),
         ...     coords=[pd.date_range("1/1/2000", periods=5)],
         ...     dims="time",
         ... )
         >>> da
         <xarray.DataArray (time: 5)>
-        array([0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ])
+        array([5, 4, 3, 2, 1])
         Coordinates:
          * time     (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05
 
         >>> da.sortby(da)
         <xarray.DataArray (time: 5)>
-        array([0.4236548 , 0.54488318, 0.5488135 , 0.60276338, 0.71518937])
+        array([1, 2, 3, 4, 5])
         Coordinates:
-         * time     (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-02
+         * time     (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01
+
+        >>> da.sortby(lambda x: x)
+        <xarray.DataArray (time: 5)>
+        array([1, 2, 3, 4, 5])
+        Coordinates:
+         * time     (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01
         """
+        # We need to convert the callable here rather than pass it through to the
+        # dataset method, since otherwise the dataset method would try to call the
+        # callable with the dataset as the object
+        if callable(variables):
+            variables = variables(self)
         ds = self._to_temp_dataset().sortby(variables, ascending=ascending)
         return self._from_temp_dataset(ds)
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 5f709a5cd63..e49c981b827 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -7838,7 +7838,10 @@ def roll(
 
     def sortby(
         self,
-        variables: Hashable | DataArray | list[Hashable | DataArray],
+        variables: Hashable
+        | DataArray
+        | Sequence[Hashable | DataArray]
+        | Callable[[Self], Hashable | DataArray | list[Hashable | DataArray]],
         ascending: bool = True,
     ) -> Self:
         """
@@ -7860,9 +7863,10 @@ def sortby(
 
         Parameters
         ----------
-        variables : Hashable, DataArray, or list of hashable or DataArray
-            1D DataArray objects or name(s) of 1D variable(s) in
-            coords/data_vars whose values are used to sort the dataset.
+        variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable
+            1D DataArray objects or name(s) of 1D variable(s) in coords whose values are
+            used to sort this dataset. If a callable, the callable is passed this object,
+            and the result is used as the value(s) to sort by.
         ascending : bool, default: True
             Whether to sort by ascending or descending order.
 
@@ -7888,8 +7892,7 @@ def sortby(
         ...     },
         ...     coords={"x": ["b", "a"], "y": [1, 0]},
         ... )
-        >>> ds = ds.sortby("x")
-        >>> ds
+        >>> ds.sortby("x")
         <xarray.Dataset>
         Dimensions:  (x: 2, y: 2)
         Coordinates:
          * x        (x) <U1 'a' 'b'
          * y        (y) int64 1 0
         Data variables:
             A        (x, y) int64 3 4 1 2
             B        (x, y) int64 7 8 5 6
+        >>> ds.sortby(lambda x: -x["y"])
+        <xarray.Dataset>
+        Dimensions:  (x: 2, y: 2)
+        Coordinates:
+         * x        (x) <U1 'b' 'a'
+         * y        (y) int64 1 0
+        Data variables:
+            A        (x, y) int64 1 2 3 4
+            B        (x, y) int64 5 6 7 8
         """

From: Spencer Clark
Date: Fri, 6 Oct 2023 10:08:51 -0400
Subject: [PATCH 37/46] Fix time encoding regression (#8272)

---
 doc/whats-new.rst                 |  6 ++++++
 xarray/coding/times.py            |  3 ++-
 xarray/tests/test_coding_times.py | 11 ++++++++---
 3 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 55e0fbaf177..92acc3f90c0 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -50,6 +50,12 @@ Bug fixes
   when the operation was a no-op. (:issue:`8266`)
   By `Simon Hansen `_.
 
+- Fix datetime encoding precision loss regression introduced in the previous
+  release for datetimes encoded with units requiring floating point values, and
+  a reference date not equal to the first value of the datetime array
+  (:issue:`8271`, :pull:`8272`). By `Spencer Clark
+  `_.
+ Documentation ~~~~~~~~~~~~~ diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 2822f02dd8d..039fe371100 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -714,7 +714,8 @@ def encode_cf_datetime( if data_units != units: # this accounts for differences in the reference times ref_delta = abs(data_ref_date - ref_date).to_timedelta64() - if ref_delta > np.timedelta64(0, "ns"): + data_delta = _time_units_to_timedelta64(needed_units) + if (ref_delta % data_delta) > np.timedelta64(0, "ns"): needed_units = _infer_time_units_from_diff(ref_delta) # needed time delta to encode faithfully to int64 diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 5f76a4a2ca8..423e48bd155 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1363,10 +1363,15 @@ def test_roundtrip_timedelta64_nanosecond_precision_warning() -> None: def test_roundtrip_float_times() -> None: + # Regression test for GitHub issue #8271 fill_value = 20.0 - times = [np.datetime64("2000-01-01 12:00:00", "ns"), np.datetime64("NaT", "ns")] + times = [ + np.datetime64("1970-01-01 00:00:00", "ns"), + np.datetime64("1970-01-01 06:00:00", "ns"), + np.datetime64("NaT", "ns"), + ] - units = "days since 2000-01-01" + units = "days since 1960-01-01" var = Variable( ["time"], times, @@ -1374,7 +1379,7 @@ def test_roundtrip_float_times() -> None: ) encoded_var = conventions.encode_cf_variable(var) - np.testing.assert_array_equal(encoded_var, np.array([0.5, 20.0])) + np.testing.assert_array_equal(encoded_var, np.array([3653, 3653.25, 20.0])) assert encoded_var.attrs["units"] == units assert encoded_var.attrs["_FillValue"] == fill_value From e8be4bbb961f58ba733852c998f2863f3ff644b1 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 8 Oct 2023 02:49:34 +0200 Subject: [PATCH 38/46] Update ci-additional.yaml (#8280) --- .github/workflows/ci-additional.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index ec1c192fd35..766937ba761 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -82,8 +82,6 @@ jobs: name: Mypy runs-on: "ubuntu-latest" needs: detect-ci-trigger - # temporarily skipping due to https://github.com/pydata/xarray/issues/6551 - if: needs.detect-ci-trigger.outputs.triggered == 'false' defaults: run: shell: bash -l {0} From 7aa207b250a50d94410c716e6f624286f9887650 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 9 Oct 2023 12:21:39 +0200 Subject: [PATCH 39/46] Improved typing of align & broadcast (#8234) * add overloads to align * add overloads to broadcast as well * add some more typing * remove unused ignore --- xarray/core/alignment.py | 205 +++++++++++++++++++++++++++++++---- xarray/core/common.py | 2 +- xarray/core/computation.py | 20 +++- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 8 +- xarray/core/merge.py | 9 +- xarray/core/types.py | 3 +- xarray/tests/test_dataset.py | 4 +- 8 files changed, 216 insertions(+), 37 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index ff2ecbc74a1..7d9ba4f4b94 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -5,7 +5,7 @@ from collections import defaultdict from collections.abc import Hashable, Iterable, Mapping from contextlib import suppress -from typing import TYPE_CHECKING, Any, Callable, Generic, cast +from typing import TYPE_CHECKING, Any, Callable, Final, 
Generic, TypeVar, cast, overload import numpy as np import pandas as pd @@ -26,7 +26,13 @@ if TYPE_CHECKING: from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - from xarray.core.types import JoinOptions, T_DataArray, T_Dataset, T_DuckArray + from xarray.core.types import ( + Alignable, + JoinOptions, + T_DataArray, + T_Dataset, + T_DuckArray, + ) def reindex_variables( @@ -128,7 +134,7 @@ def __init__( objects: Iterable[T_Alignable], join: str = "inner", indexes: Mapping[Any, Any] | None = None, - exclude_dims: Iterable = frozenset(), + exclude_dims: str | Iterable[Hashable] = frozenset(), exclude_vars: Iterable[Hashable] = frozenset(), method: str | None = None, tolerance: int | float | Iterable[int | float] | None = None, @@ -576,12 +582,111 @@ def align(self) -> None: self.reindex_all() +T_Obj1 = TypeVar("T_Obj1", bound="Alignable") +T_Obj2 = TypeVar("T_Obj2", bound="Alignable") +T_Obj3 = TypeVar("T_Obj3", bound="Alignable") +T_Obj4 = TypeVar("T_Obj4", bound="Alignable") +T_Obj5 = TypeVar("T_Obj5", bound="Alignable") + + +@overload +def align( + obj1: T_Obj1, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1]: + ... + + +@overload +def align( # type: ignore[misc] + obj1: T_Obj1, + obj2: T_Obj2, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2]: + ... + + +@overload +def align( # type: ignore[misc] + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2, T_Obj3]: + ... + + +@overload +def align( # type: ignore[misc] + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]: + ... + + +@overload +def align( # type: ignore[misc] + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + obj5: T_Obj5, + /, + *, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]: + ... + + +@overload def align( *objects: T_Alignable, join: JoinOptions = "inner", copy: bool = True, indexes=None, - exclude=frozenset(), + exclude: str | Iterable[Hashable] = frozenset(), + fill_value=dtypes.NA, +) -> tuple[T_Alignable, ...]: + ... + + +def align( # type: ignore[misc] + *objects: T_Alignable, + join: JoinOptions = "inner", + copy: bool = True, + indexes=None, + exclude: str | Iterable[Hashable] = frozenset(), fill_value=dtypes.NA, ) -> tuple[T_Alignable, ...]: """ @@ -620,7 +725,7 @@ def align( indexes : dict-like, optional Any indexes explicitly provided with the `indexes` argument should be used in preference to the aligned indexes. - exclude : sequence of str, optional + exclude : str, iterable of hashable or None, optional Dimensions that must be excluded from alignment fill_value : scalar or dict-like, optional Value to use for newly missing values. 
If a dict-like, maps @@ -787,12 +892,12 @@ def align( def deep_align( objects: Iterable[Any], join: JoinOptions = "inner", - copy=True, + copy: bool = True, indexes=None, - exclude=frozenset(), - raise_on_invalid=True, + exclude: str | Iterable[Hashable] = frozenset(), + raise_on_invalid: bool = True, fill_value=dtypes.NA, -): +) -> list[Any]: """Align objects for merging, recursing into dictionary values. This function is not public API. @@ -807,12 +912,12 @@ def deep_align( def is_alignable(obj): return isinstance(obj, (Coordinates, DataArray, Dataset)) - positions = [] - keys = [] - out = [] - targets = [] - no_key = object() - not_replaced = object() + positions: list[int] = [] + keys: list[type[object] | Hashable] = [] + out: list[Any] = [] + targets: list[Alignable] = [] + no_key: Final = object() + not_replaced: Final = object() for position, variables in enumerate(objects): if is_alignable(variables): positions.append(position) @@ -857,7 +962,7 @@ def is_alignable(obj): if key is no_key: out[position] = aligned_obj else: - out[position][key] = aligned_obj # type: ignore[index] # maybe someone can fix this? + out[position][key] = aligned_obj return out @@ -988,9 +1093,69 @@ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset: raise ValueError("all input must be Dataset or DataArray objects") -# TODO: this typing is too restrictive since it cannot deal with mixed -# DataArray and Dataset types...? Is this a problem? -def broadcast(*args: T_Alignable, exclude=None) -> tuple[T_Alignable, ...]: +@overload +def broadcast( + obj1: T_Obj1, /, *, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Obj1]: + ... + + +@overload +def broadcast( # type: ignore[misc] + obj1: T_Obj1, obj2: T_Obj2, /, *, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Obj1, T_Obj2]: + ... + + +@overload +def broadcast( # type: ignore[misc] + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3]: + ... + + +@overload +def broadcast( # type: ignore[misc] + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]: + ... + + +@overload +def broadcast( # type: ignore[misc] + obj1: T_Obj1, + obj2: T_Obj2, + obj3: T_Obj3, + obj4: T_Obj4, + obj5: T_Obj5, + /, + *, + exclude: str | Iterable[Hashable] | None = None, +) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]: + ... + + +@overload +def broadcast( + *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Alignable, ...]: + ... + + +def broadcast( # type: ignore[misc] + *args: T_Alignable, exclude: str | Iterable[Hashable] | None = None +) -> tuple[T_Alignable, ...]: """Explicitly broadcast any number of DataArray or Dataset objects against one another. @@ -1004,7 +1169,7 @@ def broadcast(*args: T_Alignable, exclude=None) -> tuple[T_Alignable, ...]: ---------- *args : DataArray or Dataset Arrays to broadcast against each other. - exclude : sequence of str, optional + exclude : str, iterable of hashable or None, optional Dimensions that must not be broadcasted Returns diff --git a/xarray/core/common.py b/xarray/core/common.py index f571576850c..ab8a4d84261 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1163,7 +1163,7 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r} (or a callable than returns one)." 
) - self, cond = align(self, cond) # type: ignore[assignment] + self, cond = align(self, cond) def _dataarray_indexer(dim: Hashable) -> DataArray: return cond.any(dim=(d for d in cond.dims if d != dim)) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index db786910f22..9cb60e0c424 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -289,8 +289,14 @@ def apply_dataarray_vfunc( from xarray.core.dataarray import DataArray if len(args) > 1: - args = deep_align( - args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False + args = tuple( + deep_align( + args, + join=join, + copy=False, + exclude=exclude_dims, + raise_on_invalid=False, + ) ) objs = _all_of_type(args, DataArray) @@ -506,8 +512,14 @@ def apply_dataset_vfunc( objs = _all_of_type(args, Dataset) if len(args) > 1: - args = deep_align( - args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False + args = tuple( + deep_align( + args, + join=join, + copy=False, + exclude=exclude_dims, + raise_on_invalid=False, + ) ) list_of_coords, list_of_indexes = build_output_coords_and_indexes( diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 2bcc5ab85e2..cc5d4a8744c 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4639,7 +4639,7 @@ def _binary_op( return NotImplemented if isinstance(other, DataArray): align_type = OPTIONS["arithmetic_join"] - self, other = align(self, other, join=align_type, copy=False) # type: ignore[type-var,assignment] + self, other = align(self, other, join=align_type, copy=False) other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other) other_coords = getattr(other, "coords", None) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e49c981b827..a1faa538564 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7508,7 +7508,7 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: return NotImplemented align_type = OPTIONS["arithmetic_join"] if join is None else join if isinstance(other, (DataArray, Dataset)): - self, other = align(self, other, join=align_type, copy=False) # type: ignore[assignment] + self, other = align(self, other, join=align_type, copy=False) g = f if not reflexive else lambda x, y: f(y, x) ds = self._calculate_binary_op(g, other, join=align_type) keep_attrs = _get_keep_attrs(default=False) @@ -7920,9 +7920,9 @@ def sortby( else: variables = variables arrays = [v if isinstance(v, DataArray) else self[v] for v in variables] - aligned_vars = align(self, *arrays, join="left") # type: ignore[type-var] - aligned_self = cast(Self, aligned_vars[0]) - aligned_other_vars: tuple[DataArray, ...] = aligned_vars[1:] # type: ignore[assignment] + aligned_vars = align(self, *arrays, join="left") + aligned_self = aligned_vars[0] + aligned_other_vars: tuple[DataArray, ...] 
= aligned_vars[1:] vars_by_dim = defaultdict(list) for data_array in aligned_other_vars: if data_array.ndim != 1: diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 3475db4a010..a8e54ad1231 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -474,10 +474,11 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset - out = [] + out: list[DatasetLike] = [] for obj in objects: + variables: DatasetLike if isinstance(obj, (Dataset, Coordinates)): - variables: DatasetLike = obj + variables = obj else: variables = {} if isinstance(obj, PANDAS_TYPES): @@ -491,7 +492,7 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik def _get_priority_vars_and_indexes( - objects: list[DatasetLike], + objects: Sequence[DatasetLike], priority_arg: int | None, compat: CompatOptions = "equals", ) -> dict[Hashable, MergeElement]: @@ -503,7 +504,7 @@ def _get_priority_vars_and_indexes( Parameters ---------- - objects : list of dict-like of Variable + objects : sequence of dict-like of Variable Dictionaries in which to find the priority variables. priority_arg : int or None Integer object whose variable should take priority. diff --git a/xarray/core/types.py b/xarray/core/types.py index 795283fa88b..2af9591d22a 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -145,6 +145,8 @@ def copy( ... +T_Alignable = TypeVar("T_Alignable", bound="Alignable") + T_Backend = TypeVar("T_Backend", bound="BackendEntrypoint") T_Dataset = TypeVar("T_Dataset", bound="Dataset") T_DataArray = TypeVar("T_DataArray", bound="DataArray") @@ -168,7 +170,6 @@ def copy( # on `DataWithCoords`. T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords") -T_Alignable = TypeVar("T_Alignable", bound="Alignable") # Temporary placeholder for indicating an array api compliant type. 
# hopefully in the future we can narrow this down more: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 3841398ff75..12347c8d62e 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2332,9 +2332,9 @@ def test_align(self) -> None: assert np.isnan(left2["var3"][-2:]).all() with pytest.raises(ValueError, match=r"invalid value for join"): - align(left, right, join="foobar") # type: ignore[arg-type] + align(left, right, join="foobar") # type: ignore[call-overload] with pytest.raises(TypeError): - align(left, right, foo="bar") # type: ignore[call-arg] + align(left, right, foo="bar") # type: ignore[call-overload] def test_align_exact(self) -> None: left = xr.Dataset(coords={"x": [0, 1]}) From ab3dd59fa4da78d391c0792715e69e5d06ad89f5 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Mon, 9 Oct 2023 14:32:28 +0200 Subject: [PATCH 40/46] Add pyright type checker (#8279) * Add pyright type checker * Update ci-additional.yaml * Update ci-additional.yaml --- .github/workflows/ci-additional.yaml | 120 +++++++++++++++++++++++++++ pyproject.toml | 25 ++++++ 2 files changed, 145 insertions(+) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index 766937ba761..dc9cc2cd2fe 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -188,6 +188,126 @@ jobs: + pyright: + name: Pyright + runs-on: "ubuntu-latest" + needs: detect-ci-trigger + if: | + always() + && ( + contains( github.event.pull_request.labels.*.name, 'run-pyright') + ) + defaults: + run: + shell: bash -l {0} + env: + CONDA_ENV_FILE: ci/requirements/environment.yml + PYTHON_VERSION: "3.10" + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. + + - name: set environment variables + run: | + echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + - name: Setup micromamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: ${{env.CONDA_ENV_FILE}} + environment-name: xarray-tests + create-args: >- + python=${{env.PYTHON_VERSION}} + conda + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + - name: Install xarray + run: | + python -m pip install --no-deps -e . + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Install pyright + run: | + python -m pip install pyright --force-reinstall + + - name: Run pyright + run: | + python -m pyright xarray/ + + - name: Upload pyright coverage to Codecov + uses: codecov/codecov-action@v3.1.4 + with: + file: pyright_report/cobertura.xml + flags: pyright + env_vars: PYTHON_VERSION + name: codecov-umbrella + fail_ci_if_error: false + + pyright39: + name: Pyright 3.9 + runs-on: "ubuntu-latest" + needs: detect-ci-trigger + if: | + always() + && ( + contains( github.event.pull_request.labels.*.name, 'run-pyright') + ) + defaults: + run: + shell: bash -l {0} + env: + CONDA_ENV_FILE: ci/requirements/environment.yml + PYTHON_VERSION: "3.9" + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. 
+ + - name: set environment variables + run: | + echo "TODAY=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + - name: Setup micromamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: ${{env.CONDA_ENV_FILE}} + environment-name: xarray-tests + create-args: >- + python=${{env.PYTHON_VERSION}} + conda + cache-environment: true + cache-environment-key: "${{runner.os}}-${{runner.arch}}-py${{env.PYTHON_VERSION}}-${{env.TODAY}}-${{hashFiles(env.CONDA_ENV_FILE)}}" + - name: Install xarray + run: | + python -m pip install --no-deps -e . + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Install pyright + run: | + python -m pip install pyright --force-reinstall + + - name: Run pyright + run: | + python -m pyright xarray/ + + - name: Upload pyright coverage to Codecov + uses: codecov/codecov-action@v3.1.4 + with: + file: pyright_report/cobertura.xml + flags: pyright39 + env_vars: PYTHON_VERSION + name: codecov-umbrella + fail_ci_if_error: false + + + min-version-policy: name: Minimum Version Policy runs-on: "ubuntu-latest" diff --git a/pyproject.toml b/pyproject.toml index e24f88d9679..1a24a4b4eda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -198,6 +198,31 @@ warn_return_any = true module = ["xarray.namedarray.*", "xarray.tests.test_namedarray"] +[tool.pyright] +# include = ["src"] +# exclude = ["**/node_modules", + # "**/__pycache__", + # "src/experimental", + # "src/typestubs" +# ] +# ignore = ["src/oldstuff"] +defineConstant = { DEBUG = true } +# stubPath = "src/stubs" +# venv = "env367" + +reportMissingImports = true +reportMissingTypeStubs = false + +# pythonVersion = "3.6" +# pythonPlatform = "Linux" + +# executionEnvironments = [ + # { root = "src/web", pythonVersion = "3.5", pythonPlatform = "Windows", extraPaths = [ "src/service_libs" ] }, + # { root = "src/sdk", pythonVersion = "3.0", extraPaths = [ "src/backend" ] }, + # { root = "src/tests", extraPaths = ["src/tests/e2e", "src/sdk" ]}, + # { root = "src" } +# ] + [tool.ruff] builtins = ["ellipsis"] exclude = [ From 129c4ac408d73684fe2cc0682111a2b60b9fbac8 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 9 Oct 2023 06:30:03 -0700 Subject: [PATCH 41/46] Ask bug reporters to confirm they're using a recent version of xarray (#8283) --- .github/ISSUE_TEMPLATE/bugreport.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/ISSUE_TEMPLATE/bugreport.yml b/.github/ISSUE_TEMPLATE/bugreport.yml index 59e5889f5ec..cc1a2e12be3 100644 --- a/.github/ISSUE_TEMPLATE/bugreport.yml +++ b/.github/ISSUE_TEMPLATE/bugreport.yml @@ -44,6 +44,7 @@ body: - label: Complete example — the example is self-contained, including all data and the text of any traceback. - label: Verifiable example — the example copy & pastes into an IPython prompt or [Binder notebook](https://mybinder.org/v2/gh/pydata/xarray/main?urlpath=lab/tree/doc/examples/blank_template.ipynb), returning the result. - label: New issue — a search of GitHub Issues suggests this is not a duplicate. + - label: Recent environment — the issue occurs with the latest version of xarray and its dependencies. 
- type: textarea id: log-output From 46643bb1a4bdbf5cd5c584a907825dda9daa9001 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 9 Oct 2023 06:38:44 -0700 Subject: [PATCH 42/46] Fix `GroupBy` import (#8286) --- xarray/tests/test_units.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index d89a74e4fba..7e1105e2e5d 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -305,11 +305,13 @@ def __call__(self, obj, *args, **kwargs): all_args = merge_args(self.args, args) all_kwargs = {**self.kwargs, **kwargs} + from xarray.core.groupby import GroupBy + xarray_classes = ( xr.Variable, xr.DataArray, xr.Dataset, - xr.core.groupby.GroupBy, + GroupBy, ) if not isinstance(obj, xarray_classes): From 75af56c33a29529269a73bdd00df2d3af17ee0f5 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 9 Oct 2023 23:37:19 -0700 Subject: [PATCH 43/46] Enable `.rolling_exp` to work on dask arrays (#8284) --- doc/whats-new.rst | 3 +++ xarray/core/rolling_exp.py | 4 ++-- xarray/tests/test_rolling.py | 4 +++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 92acc3f90c0..8f576f486dc 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,9 @@ New Features - :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for the ``variables`` parameter, passing the object as the only argument. By `Maximilian Roos `_. +- ``.rolling_exp`` functions can now operate on dask-backed arrays, assuming the + core dim has exactly one chunk. (:pull:`8284`). + By `Maximilian Roos `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index c56bf6a384e..cb77358869c 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -147,9 +147,9 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: input_core_dims=[[self.dim]], kwargs=dict(alpha=self.alpha, axis=-1), output_core_dims=[[self.dim]], - exclude_dims={self.dim}, keep_attrs=keep_attrs, on_missing_core_dim="copy", + dask="parallelized", ).transpose(*dim_order) def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: @@ -183,7 +183,7 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: input_core_dims=[[self.dim]], kwargs=dict(alpha=self.alpha, axis=-1), output_core_dims=[[self.dim]], - exclude_dims={self.dim}, keep_attrs=keep_attrs, on_missing_core_dim="copy", + dask="parallelized", ).transpose(*dim_order) diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 2dc8ae24438..da834b76124 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -788,7 +788,9 @@ def test_raise_no_warning_dask_rolling_assert_close(self, ds, name) -> None: @requires_numbagg class TestDatasetRollingExp: - @pytest.mark.parametrize("backend", ["numpy"], indirect=True) + @pytest.mark.parametrize( + "backend", ["numpy", pytest.param("dask", marks=requires_dask)], indirect=True + ) def test_rolling_exp(self, ds) -> None: result = ds.rolling_exp(time=10, window_type="span").mean() assert isinstance(result, Dataset) From d50a5e5122d71078bcf311d54beed652506080cb Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Thu, 12 Oct 2023 10:11:03 -0700 Subject: [PATCH 44/46] Rename `reset_encoding` to `drop_encoding` (#8287) * Rename 
`reset_encoding` to `drop_encoding` Closes #8259 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update dataarray.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update xarray/core/dataarray.py Co-authored-by: Michael Niklas * Update xarray/core/variable.py Co-authored-by: Michael Niklas * Update xarray/core/dataset.py Co-authored-by: Michael Niklas * api --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> Co-authored-by: Michael Niklas --- doc/api-hidden.rst | 2 +- doc/api.rst | 4 ++-- doc/user-guide/io.rst | 6 +++--- doc/whats-new.rst | 7 ++++++- xarray/core/dataarray.py | 8 +++++++- xarray/core/dataset.py | 8 +++++++- xarray/core/variable.py | 6 ++++++ xarray/tests/test_dataarray.py | 4 ++-- xarray/tests/test_dataset.py | 4 ++-- xarray/tests/test_variable.py | 6 +++--- 10 files changed, 39 insertions(+), 16 deletions(-) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index d97c4010528..552d11a06dc 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -265,7 +265,7 @@ Variable.dims Variable.dtype Variable.encoding - Variable.reset_encoding + Variable.drop_encoding Variable.imag Variable.nbytes Variable.ndim diff --git a/doc/api.rst b/doc/api.rst index 0cf07f91df8..96b4864804f 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -110,9 +110,9 @@ Dataset contents Dataset.drop_indexes Dataset.drop_duplicates Dataset.drop_dims + Dataset.drop_encoding Dataset.set_coords Dataset.reset_coords - Dataset.reset_encoding Dataset.convert_calendar Dataset.interp_calendar Dataset.get_index @@ -303,8 +303,8 @@ DataArray contents DataArray.drop_vars DataArray.drop_indexes DataArray.drop_duplicates + DataArray.drop_encoding DataArray.reset_coords - DataArray.reset_encoding DataArray.copy DataArray.convert_calendar DataArray.interp_calendar diff --git a/doc/user-guide/io.rst b/doc/user-guide/io.rst index 2ffc25b2009..ffded682035 100644 --- a/doc/user-guide/io.rst +++ b/doc/user-guide/io.rst @@ -260,12 +260,12 @@ Note that all operations that manipulate variables other than indexing will remove encoding information. In some cases it is useful to intentionally reset a dataset's original encoding values. -This can be done with either the :py:meth:`Dataset.reset_encoding` or -:py:meth:`DataArray.reset_encoding` methods. +This can be done with either the :py:meth:`Dataset.drop_encoding` or +:py:meth:`DataArray.drop_encoding` methods. .. ipython:: python - ds_no_encoding = ds_disk.reset_encoding() + ds_no_encoding = ds_disk.drop_encoding() ds_no_encoding.encoding .. _combining multiple files: diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8f576f486dc..40c50e158ad 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -45,7 +45,12 @@ Breaking changes Deprecations ~~~~~~~~~~~~ - +- Rename :py:meth:`Dataset.reset_encoding` & :py:meth:`DataArray.reset_encoding` + to :py:meth:`Dataset.drop_encoding` & :py:meth:`DataArray.drop_encoding` for + consistency with other ``drop`` & ``reset`` methods — ``drop`` generally + removes something, while ``reset`` generally resets to some default or + standard value. (:pull:`8287`, :issue:`8259`) + By `Maximilian Roos `_. 
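A minimal sketch of the migration this deprecation entry implies (hedged: the warning text comes from this patch, while the dataset below is purely illustrative)::

    import warnings

    import xarray as xr

    ds = xr.Dataset({"v": ("x", [1.0, 2.0])})
    ds["v"].encoding["scale_factor"] = 10

    clean = ds.drop_encoding()  # new spelling; strips encoding everywhere
    assert clean["v"].encoding == {}

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        ds.reset_encoding()  # deprecated alias, same result
    assert any("drop_encoding" in str(w.message) for w in caught)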
Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index cc5d4a8744c..391b4ed9412 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -914,9 +914,15 @@ def encoding(self, value: Mapping[Any, Any]) -> None: self.variable.encoding = dict(value) def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: """Return a new DataArray without encoding on the array or any attached coords.""" - ds = self._to_temp_dataset().reset_encoding() + ds = self._to_temp_dataset().drop_encoding() return self._from_temp_dataset(ds) @property diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index a1faa538564..ebd6fb6f51f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -756,9 +756,15 @@ def encoding(self, value: Mapping[Any, Any]) -> None: self._encoding = dict(value) def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: """Return a new Dataset without encoding on the dataset or any of its variables/coords.""" - variables = {k: v.reset_encoding() for k, v in self.variables.items()} + variables = {k: v.drop_encoding() for k, v in self.variables.items()} return self._replace(variables=variables, encoding={}) @property diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 3baecfe5f6d..fa5523b1340 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -883,6 +883,12 @@ def encoding(self, value): raise ValueError("encoding must be castable to a dictionary") def reset_encoding(self) -> Self: + warnings.warn( + "reset_encoding is deprecated since 2023.11, use `drop_encoding` instead" + ) + return self.drop_encoding() + + def drop_encoding(self) -> Self: """Return a new Variable without encoding.""" return self._replace(encoding={}) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index d497cd5a54d..5eb5394d58e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -287,7 +287,7 @@ def test_encoding(self) -> None: self.dv.encoding = expected2 assert expected2 is not self.dv.encoding - def test_reset_encoding(self) -> None: + def test_drop_encoding(self) -> None: array = self.mda encoding = {"scale_factor": 10} array.encoding = encoding @@ -296,7 +296,7 @@ def test_reset_encoding(self) -> None: assert array.encoding == encoding assert array["x"].encoding == encoding - actual = array.reset_encoding() + actual = array.drop_encoding() # did not modify in place assert array.encoding == encoding diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 12347c8d62e..687aae8f1dc 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2959,7 +2959,7 @@ def test_copy_with_data_errors(self) -> None: with pytest.raises(ValueError, match=r"contain all variables in original"): orig.copy(data={"var1": new_var1}) - def test_reset_encoding(self) -> None: + def test_drop_encoding(self) -> None: orig = create_test_data() vencoding = {"scale_factor": 10} orig.encoding = {"foo": "bar"} @@ -2967,7 +2967,7 @@ def test_reset_encoding(self) -> None: for k, v in orig.variables.items(): orig[k].encoding = vencoding - actual = orig.reset_encoding() + actual = orig.drop_encoding() assert actual.encoding == {} for k, v in actual.variables.items(): assert 
v.encoding == {} diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 1ffd51f4a04..73238b6ae3a 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -473,12 +473,12 @@ def test_encoding_preserved(self): assert_identical(expected.to_base_variable(), actual.to_base_variable()) assert expected.encoding == actual.encoding - def test_reset_encoding(self) -> None: + def test_drop_encoding(self) -> None: encoding1 = {"scale_factor": 1} # encoding set via cls constructor v1 = self.cls(["a"], [0, 1, 2], encoding=encoding1) assert v1.encoding == encoding1 - v2 = v1.reset_encoding() + v2 = v1.drop_encoding() assert v1.encoding == encoding1 assert v2.encoding == {} @@ -486,7 +486,7 @@ def test_reset_encoding(self) -> None: encoding3 = {"scale_factor": 10} v3 = self.cls(["a"], [0, 1, 2], encoding=encoding3) assert v3.encoding == encoding3 - v4 = v3.reset_encoding() + v4 = v3.drop_encoding() assert v3.encoding == encoding3 assert v4.encoding == {} From 25e6e084aa18c49d92934db298a3efff9c712766 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Thu, 12 Oct 2023 12:06:12 -0700 Subject: [PATCH 45/46] Most of mypy 1.6.0 passing (#8296) * Most of mypy 1.6.0 passing * fix typed ops --------- Co-authored-by: Michael Niklas --- pyproject.toml | 29 +++++++-------- xarray/core/_typed_ops.py | 74 ++++++++++++++++++------------------- xarray/core/alignment.py | 16 ++++---- xarray/util/generate_ops.py | 14 ++++--- 4 files changed, 67 insertions(+), 66 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1a24a4b4eda..bdae33e4d0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,7 @@ files = "xarray" show_error_codes = true show_error_context = true warn_redundant_casts = true +warn_unused_configs = true warn_unused_ignores = true # Much of the numerical computing stack doesn't have type annotations yet. 
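For orientation between these two pyproject hunks, a hedged sketch of what the key gradation in the block below flags in practice (illustrative snippet, not from the patch)::

    def untyped(scale):  # rejected once ``disallow_untyped_defs`` is enabled
        return scale * 2.0

    def typed(scale: float) -> float:  # accepted under the stricter setting
        return scale * 2.0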
@@ -168,26 +169,24 @@ module = [ # ref: https://mypy.readthedocs.io/en/stable/existing_code.html#introduce-stricter-options [[tool.mypy.overrides]] # Start off with these -warn_unused_configs = true -warn_redundant_casts = true warn_unused_ignores = true # Getting these passing should be easy -strict_equality = true strict_concatenate = true +strict_equality = true # Strongly recommend enabling this one as soon as you can check_untyped_defs = true # These shouldn't be too much additional work, but may be tricky to # get passing if you use a lot of untyped libraries +disallow_any_generics = true disallow_subclassing_any = true disallow_untyped_decorators = true -disallow_any_generics = true # These next few are various gradations of forcing use of type annotations -disallow_untyped_calls = true disallow_incomplete_defs = true +disallow_untyped_calls = true disallow_untyped_defs = true # This one isn't too hard to get passing, but return on investment is lower @@ -201,12 +200,12 @@ module = ["xarray.namedarray.*", "xarray.tests.test_namedarray"] [tool.pyright] # include = ["src"] # exclude = ["**/node_modules", - # "**/__pycache__", - # "src/experimental", - # "src/typestubs" +# "**/__pycache__", +# "src/experimental", +# "src/typestubs" # ] # ignore = ["src/oldstuff"] -defineConstant = { DEBUG = true } +defineConstant = {DEBUG = true} # stubPath = "src/stubs" # venv = "env367" @@ -217,10 +216,10 @@ reportMissingTypeStubs = false # pythonPlatform = "Linux" # executionEnvironments = [ - # { root = "src/web", pythonVersion = "3.5", pythonPlatform = "Windows", extraPaths = [ "src/service_libs" ] }, - # { root = "src/sdk", pythonVersion = "3.0", extraPaths = [ "src/backend" ] }, - # { root = "src/tests", extraPaths = ["src/tests/e2e", "src/sdk" ]}, - # { root = "src" } +# { root = "src/web", pythonVersion = "3.5", pythonPlatform = "Windows", extraPaths = [ "src/service_libs" ] }, +# { root = "src/sdk", pythonVersion = "3.0", extraPaths = [ "src/backend" ] }, +# { root = "src/tests", extraPaths = ["src/tests/e2e", "src/sdk" ]}, +# { root = "src" } # ] [tool.ruff] @@ -252,16 +251,16 @@ known-first-party = ["xarray"] [tool.pytest.ini_options] addopts = ["--strict-config", "--strict-markers"] -log_cli_level = "INFO" -minversion = "7" filterwarnings = [ "ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning", ] +log_cli_level = "INFO" markers = [ "flaky: flaky tests", "network: tests requiring a network connection", "slow: slow tests", ] +minversion = "7" python_files = "test_*.py" testpaths = ["xarray/tests", "properties"] diff --git a/xarray/core/_typed_ops.py b/xarray/core/_typed_ops.py index 330d13bb217..9b79ed46a9c 100644 --- a/xarray/core/_typed_ops.py +++ b/xarray/core/_typed_ops.py @@ -4,7 +4,7 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Callable, NoReturn, overload +from typing import TYPE_CHECKING, Any, Callable, overload from xarray.core import nputils, ops from xarray.core.types import ( @@ -446,201 +446,201 @@ def _binary_op( raise NotImplementedError @overload - def __add__(self, other: T_DataArray) -> NoReturn: + def __add__(self, other: T_DataArray) -> T_DataArray: ... @overload def __add__(self, other: VarCompatible) -> Self: ... 
- def __add__(self, other: VarCompatible) -> Self: + def __add__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.add) @overload - def __sub__(self, other: T_DataArray) -> NoReturn: + def __sub__(self, other: T_DataArray) -> T_DataArray: ... @overload def __sub__(self, other: VarCompatible) -> Self: ... - def __sub__(self, other: VarCompatible) -> Self: + def __sub__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.sub) @overload - def __mul__(self, other: T_DataArray) -> NoReturn: + def __mul__(self, other: T_DataArray) -> T_DataArray: ... @overload def __mul__(self, other: VarCompatible) -> Self: ... - def __mul__(self, other: VarCompatible) -> Self: + def __mul__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.mul) @overload - def __pow__(self, other: T_DataArray) -> NoReturn: + def __pow__(self, other: T_DataArray) -> T_DataArray: ... @overload def __pow__(self, other: VarCompatible) -> Self: ... - def __pow__(self, other: VarCompatible) -> Self: + def __pow__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.pow) @overload - def __truediv__(self, other: T_DataArray) -> NoReturn: + def __truediv__(self, other: T_DataArray) -> T_DataArray: ... @overload def __truediv__(self, other: VarCompatible) -> Self: ... - def __truediv__(self, other: VarCompatible) -> Self: + def __truediv__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.truediv) @overload - def __floordiv__(self, other: T_DataArray) -> NoReturn: + def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... @overload def __floordiv__(self, other: VarCompatible) -> Self: ... - def __floordiv__(self, other: VarCompatible) -> Self: + def __floordiv__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.floordiv) @overload - def __mod__(self, other: T_DataArray) -> NoReturn: + def __mod__(self, other: T_DataArray) -> T_DataArray: ... @overload def __mod__(self, other: VarCompatible) -> Self: ... - def __mod__(self, other: VarCompatible) -> Self: + def __mod__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.mod) @overload - def __and__(self, other: T_DataArray) -> NoReturn: + def __and__(self, other: T_DataArray) -> T_DataArray: ... @overload def __and__(self, other: VarCompatible) -> Self: ... - def __and__(self, other: VarCompatible) -> Self: + def __and__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.and_) @overload - def __xor__(self, other: T_DataArray) -> NoReturn: + def __xor__(self, other: T_DataArray) -> T_DataArray: ... @overload def __xor__(self, other: VarCompatible) -> Self: ... - def __xor__(self, other: VarCompatible) -> Self: + def __xor__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.xor) @overload - def __or__(self, other: T_DataArray) -> NoReturn: + def __or__(self, other: T_DataArray) -> T_DataArray: ... @overload def __or__(self, other: VarCompatible) -> Self: ... - def __or__(self, other: VarCompatible) -> Self: + def __or__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.or_) @overload - def __lshift__(self, other: T_DataArray) -> NoReturn: + def __lshift__(self, other: T_DataArray) -> T_DataArray: ... @overload def __lshift__(self, other: VarCompatible) -> Self: ... 
- def __lshift__(self, other: VarCompatible) -> Self: + def __lshift__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.lshift) @overload - def __rshift__(self, other: T_DataArray) -> NoReturn: + def __rshift__(self, other: T_DataArray) -> T_DataArray: ... @overload def __rshift__(self, other: VarCompatible) -> Self: ... - def __rshift__(self, other: VarCompatible) -> Self: + def __rshift__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.rshift) @overload - def __lt__(self, other: T_DataArray) -> NoReturn: + def __lt__(self, other: T_DataArray) -> T_DataArray: ... @overload def __lt__(self, other: VarCompatible) -> Self: ... - def __lt__(self, other: VarCompatible) -> Self: + def __lt__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.lt) @overload - def __le__(self, other: T_DataArray) -> NoReturn: + def __le__(self, other: T_DataArray) -> T_DataArray: ... @overload def __le__(self, other: VarCompatible) -> Self: ... - def __le__(self, other: VarCompatible) -> Self: + def __le__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.le) @overload - def __gt__(self, other: T_DataArray) -> NoReturn: + def __gt__(self, other: T_DataArray) -> T_DataArray: ... @overload def __gt__(self, other: VarCompatible) -> Self: ... - def __gt__(self, other: VarCompatible) -> Self: + def __gt__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.gt) @overload - def __ge__(self, other: T_DataArray) -> NoReturn: + def __ge__(self, other: T_DataArray) -> T_DataArray: ... @overload def __ge__(self, other: VarCompatible) -> Self: ... - def __ge__(self, other: VarCompatible) -> Self: + def __ge__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, operator.ge) @overload # type:ignore[override] - def __eq__(self, other: T_DataArray) -> NoReturn: + def __eq__(self, other: T_DataArray) -> T_DataArray: ... @overload def __eq__(self, other: VarCompatible) -> Self: ... - def __eq__(self, other: VarCompatible) -> Self: + def __eq__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, nputils.array_eq) @overload # type:ignore[override] - def __ne__(self, other: T_DataArray) -> NoReturn: + def __ne__(self, other: T_DataArray) -> T_DataArray: ... @overload def __ne__(self, other: VarCompatible) -> Self: ... 
- def __ne__(self, other: VarCompatible) -> Self: + def __ne__(self, other: VarCompatible) -> Self | T_DataArray: return self._binary_op(other, nputils.array_ne) def __radd__(self, other: VarCompatible) -> Self: diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 7d9ba4f4b94..732ec5d3ea6 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -604,7 +604,7 @@ def align( @overload -def align( # type: ignore[misc] +def align( obj1: T_Obj1, obj2: T_Obj2, /, @@ -619,7 +619,7 @@ def align( # type: ignore[misc] @overload -def align( # type: ignore[misc] +def align( obj1: T_Obj1, obj2: T_Obj2, obj3: T_Obj3, @@ -635,7 +635,7 @@ def align( # type: ignore[misc] @overload -def align( # type: ignore[misc] +def align( obj1: T_Obj1, obj2: T_Obj2, obj3: T_Obj3, @@ -652,7 +652,7 @@ def align( # type: ignore[misc] @overload -def align( # type: ignore[misc] +def align( obj1: T_Obj1, obj2: T_Obj2, obj3: T_Obj3, @@ -1101,14 +1101,14 @@ def broadcast( @overload -def broadcast( # type: ignore[misc] +def broadcast( obj1: T_Obj1, obj2: T_Obj2, /, *, exclude: str | Iterable[Hashable] | None = None ) -> tuple[T_Obj1, T_Obj2]: ... @overload -def broadcast( # type: ignore[misc] +def broadcast( obj1: T_Obj1, obj2: T_Obj2, obj3: T_Obj3, @@ -1120,7 +1120,7 @@ def broadcast( # type: ignore[misc] @overload -def broadcast( # type: ignore[misc] +def broadcast( obj1: T_Obj1, obj2: T_Obj2, obj3: T_Obj3, @@ -1133,7 +1133,7 @@ def broadcast( # type: ignore[misc] @overload -def broadcast( # type: ignore[misc] +def broadcast( obj1: T_Obj1, obj2: T_Obj2, obj3: T_Obj3, diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index 632ca06d295..f339470884a 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -87,13 +87,15 @@ def {method}(self, other: {other_type}) -> {return_type}:{type_ignore} return self._binary_op(other, {func})""" template_binop_overload = """ @overload{overload_type_ignore} - def {method}(self, other: {overload_type}) -> NoReturn: + def {method}(self, other: {overload_type}) -> {overload_type}: ... @overload def {method}(self, other: {other_type}) -> {return_type}: ... -""" + + def {method}(self, other: {other_type}) -> {return_type} | {overload_type}:{type_ignore} + return self._binary_op(other, {func})""" template_reflexive = """ def {method}(self, other: {other_type}) -> {return_type}: return self._binary_op(other, {func}, reflexive=True)""" @@ -123,7 +125,7 @@ def {method}(self, *args: Any, **kwargs: Any) -> Self: # # We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray # In reality this returns NotImplementes, but this is not a valid type in python 3.9. -# Therefore, we use NoReturn which mypy seems to recognise! +# Therefore, we return DataArray. In reality this would call DataArray.__add__(Variable) # TODO: change once python 3.10 is the minimum. # # Mypy seems to require that __iadd__ and __add__ have the same signature. 
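A rough illustration of what the regenerated overloads mean for callers (hypothetical snippet, not part of the patch; at runtime ``Variable._binary_op`` returns ``NotImplemented`` for xarray objects, so Python dispatches to the ``DataArray`` reflexive method, as the comment above describes)::

    import xarray as xr

    v = xr.Variable("x", [1, 2, 3])
    da = xr.DataArray([10, 20, 30], dims="x")

    res = v + da  # previously annotated ``NoReturn``; now typed as ``DataArray``
    assert isinstance(res, xr.DataArray)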
@@ -165,7 +167,7 @@ def binops_overload( ([(None, None)], required_method_binary, extras), ( BINOPS_NUM + BINOPS_CMP, - template_binop_overload + template_binop, + template_binop_overload, extras | { "overload_type": overload_type, @@ -175,7 +177,7 @@ def binops_overload( ), ( BINOPS_EQNE, - template_binop_overload + template_binop, + template_binop_overload, extras | { "overload_type": overload_type, @@ -233,7 +235,7 @@ def unops() -> list[OpsType]: from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Callable, NoReturn, overload +from typing import TYPE_CHECKING, Any, Callable, overload from xarray.core import nputils, ops from xarray.core.types import ( From 338fc9268a1ea41dbb861906fdf2e79d939e2df3 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Thu, 12 Oct 2023 15:00:47 -0700 Subject: [PATCH 46/46] xfail flaky test (#8299) * xfail flaky test Would be better to fix it, but in lieu of fixing, better to skip it * . --- xarray/tests/test_backends.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0cbf3af3664..9ec67bf47dc 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3459,6 +3459,7 @@ def skip_if_not_engine(engine): @requires_dask @pytest.mark.filterwarnings("ignore:use make_scale(name) instead") +@pytest.mark.xfail(reason="Flaky test. Very open to contributions on fixing this") def test_open_mfdataset_manyfiles( readengine, nfiles, parallel, chunks, file_cache_maxsize ):
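Taken together with the ``align``/``broadcast`` overloads added earlier in this series, a minimal sketch of the net typing effect (hypothetical snippet; the ``reveal_type`` notes show what mypy reports and are comments, not runtime calls)::

    import xarray as xr

    da = xr.DataArray([1, 2, 3], dims="x", coords={"x": [0, 1, 2]})
    ds = xr.Dataset({"v": ("x", [10.0, 20.0, 30.0])}, coords={"x": [0, 1, 2]})

    # Mixed argument types no longer collapse into one TypeVar: mypy infers
    # tuple[DataArray, Dataset] for this call instead of forcing both to match.
    da2, ds2 = xr.align(da, ds, join="inner")
    # reveal_type(da2)  -> DataArray
    # reveal_type(ds2)  -> Dataset

    db, dsb = xr.broadcast(da, ds)
    assert isinstance(db, xr.DataArray) and isinstance(dsb, xr.Dataset)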