diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 63c9dee04c5..cdefa7f89e1 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -46,6 +46,10 @@ Documentation
Internal Changes
~~~~~~~~~~~~~~~~
+- More improvements to support the Python `array API standard <https://data-apis.org/array-api/latest/>`_
+  by using duck array ops in more places in the codebase. (:pull:`8267`)
+  By `Tom White <https://github.com/tomwhite>`_.
+
.. _whats-new.2023.09.0:
diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py
index 4c1ce4b5c48..8255e2a5232 100644
--- a/xarray/core/accessor_dt.py
+++ b/xarray/core/accessor_dt.py
@@ -7,6 +7,7 @@
import pandas as pd
from xarray.coding.times import infer_calendar_name
+from xarray.core import duck_array_ops
from xarray.core.common import (
_contains_datetime_like_objects,
is_np_datetime_like,
@@ -50,7 +51,7 @@ def _access_through_cftimeindex(values, name):
from xarray.coding.cftimeindex import CFTimeIndex
if not isinstance(values, CFTimeIndex):
- values_as_cftimeindex = CFTimeIndex(values.ravel())
+ values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
else:
values_as_cftimeindex = values
if name == "season":
@@ -69,7 +70,7 @@ def _access_through_series(values, name):
"""Coerce an array of datetime-like values to a pandas Series and
access requested datetime component
"""
- values_as_series = pd.Series(values.ravel(), copy=False)
+ values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
if name == "season":
months = values_as_series.dt.month.values
field_values = _season_from_months(months)
@@ -148,10 +149,10 @@ def _round_through_series_or_index(values, name, freq):
from xarray.coding.cftimeindex import CFTimeIndex
if is_np_datetime_like(values.dtype):
- values_as_series = pd.Series(values.ravel(), copy=False)
+ values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
method = getattr(values_as_series.dt, name)
else:
- values_as_cftimeindex = CFTimeIndex(values.ravel())
+ values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
method = getattr(values_as_cftimeindex, name)
field_values = method(freq=freq).values
@@ -195,7 +196,7 @@ def _strftime_through_cftimeindex(values, date_format: str):
"""
from xarray.coding.cftimeindex import CFTimeIndex
- values_as_cftimeindex = CFTimeIndex(values.ravel())
+ values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
field_values = values_as_cftimeindex.strftime(date_format)
return field_values.values.reshape(values.shape)
@@ -205,7 +206,7 @@ def _strftime_through_series(values, date_format: str):
"""Coerce an array of datetime-like values to a pandas Series and
apply string formatting
"""
- values_as_series = pd.Series(values.ravel(), copy=False)
+ values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
strs = values_as_series.dt.strftime(date_format)
return strs.values.reshape(values.shape)
diff --git a/xarray/core/computation.py b/xarray/core/computation.py
index c707403db97..db786910f22 100644
--- a/xarray/core/computation.py
+++ b/xarray/core/computation.py
@@ -2123,7 +2123,8 @@ def _calc_idxminmax(
chunkmanager = get_chunked_array_type(array.data)
chunks = dict(zip(array.dims, array.chunks))
dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim])
- res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape))
+ data = dask_coord[duck_array_ops.ravel(indx.data)]
+ res = indx.copy(data=duck_array_ops.reshape(data, indx.shape))
# we need to attach back the dim name
res.name = dim
else:
diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py
index 4f245e59f73..078aab0ed63 100644
--- a/xarray/core/duck_array_ops.py
+++ b/xarray/core/duck_array_ops.py
@@ -337,6 +337,10 @@ def reshape(array, shape):
return xp.reshape(array, shape)
+def ravel(array):
+ return reshape(array, (-1,))
+
+
@contextlib.contextmanager
def _ignore_warnings_if(condition):
if condition:
@@ -363,7 +367,7 @@ def f(values, axis=None, skipna=None, **kwargs):
values = asarray(values)
if coerce_strings and values.dtype.kind in "SU":
- values = values.astype(object)
+ values = astype(values, object)
func = None
if skipna or (skipna is None and values.dtype.kind in "cfO"):
diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py
index 3b8ddfe032d..fc7240139aa 100644
--- a/xarray/core/nanops.py
+++ b/xarray/core/nanops.py
@@ -4,7 +4,7 @@
import numpy as np
-from xarray.core import dtypes, nputils, utils
+from xarray.core import dtypes, duck_array_ops, nputils, utils
from xarray.core.duck_array_ops import (
astype,
count,
@@ -21,12 +21,16 @@ def _maybe_null_out(result, axis, mask, min_count=1):
xarray version of pandas.core.nanops._maybe_null_out
"""
if axis is not None and getattr(result, "ndim", False):
- null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0
+ null_mask = (
+ np.take(mask.shape, axis).prod()
+ - duck_array_ops.sum(mask, axis)
+ - min_count
+ ) < 0
dtype, fill_value = dtypes.maybe_promote(result.dtype)
result = where(null_mask, fill_value, astype(result, dtype))
elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
- null_mask = mask.size - mask.sum()
+ null_mask = mask.size - duck_array_ops.sum(mask)
result = where(null_mask < min_count, np.nan, result)
return result
diff --git a/xarray/core/variable.py b/xarray/core/variable.py
index 4eeda073555..3baecfe5f6d 100644
--- a/xarray/core/variable.py
+++ b/xarray/core/variable.py
@@ -2324,7 +2324,7 @@ def coarsen_reshape(self, windows, boundary, side):
else:
shape.append(variable.shape[i])
- return variable.data.reshape(shape), tuple(axes)
+ return duck_array_ops.reshape(variable.data, shape), tuple(axes)
def isnull(self, keep_attrs: bool | None = None):
"""Test each value in the array for whether it is a missing value.
diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py
index e345ae691ec..01d5393e289 100644
--- a/xarray/tests/test_coarsen.py
+++ b/xarray/tests/test_coarsen.py
@@ -6,6 +6,7 @@
import xarray as xr
from xarray import DataArray, Dataset, set_options
+from xarray.core import duck_array_ops
from xarray.tests import (
assert_allclose,
assert_equal,
@@ -272,21 +273,24 @@ def test_coarsen_construct(self, dask: bool) -> None:
expected = xr.Dataset(attrs={"foo": "bar"})
expected["vart"] = (
("year", "month"),
- ds.vart.data.reshape((-1, 12)),
+ duck_array_ops.reshape(ds.vart.data, (-1, 12)),
{"a": "b"},
)
expected["varx"] = (
("x", "x_reshaped"),
- ds.varx.data.reshape((-1, 5)),
+ duck_array_ops.reshape(ds.varx.data, (-1, 5)),
{"a": "b"},
)
expected["vartx"] = (
("x", "x_reshaped", "year", "month"),
- ds.vartx.data.reshape(2, 5, 4, 12),
+ duck_array_ops.reshape(ds.vartx.data, (2, 5, 4, 12)),
{"a": "b"},
)
expected["vary"] = ds.vary
- expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12)))
+ expected.coords["time"] = (
+ ("year", "month"),
+ duck_array_ops.reshape(ds.time.data, (-1, 12)),
+ )
with raise_if_dask_computes():
actual = ds.coarsen(time=12, x=5).construct(
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index f162b1c7d0a..1ffd51f4a04 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -916,7 +916,7 @@ def test_pad_constant_values(self, xr_arg, np_arg):
actual = v.pad(**xr_arg)
expected = np.pad(
- np.array(v.data.astype(float)),
+ np.array(duck_array_ops.astype(v.data, float)),
np_arg,
mode="constant",
constant_values=np.nan,