Skip to content

Commit

Permalink
Use duck array ops in more places (#8267)
Browse files (browse the repository at this point in the history)
* Use duck array ops for `reshape`

* Use duck array ops for `sum`

* Use duck array ops for `astype`

* Use duck array ops for `ravel`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update what's new

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
tomwhite and pre-commit-ci[bot] authored Oct 5, 2023
1 parent e09609c commit bd40c20
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 17 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ Documentation
Internal Changes
~~~~~~~~~~~~~~~~

- More improvements to support the Python `array API standard <https://data-apis.org/array-api/latest/>`_
by using duck array ops in more places in the codebase. (:pull:`8267`)
By `Tom White <https://github.com/tomwhite>`_.


.. _whats-new.2023.09.0:

Expand Down
13 changes: 7 additions & 6 deletions xarray/core/accessor_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd

from xarray.coding.times import infer_calendar_name
from xarray.core import duck_array_ops
from xarray.core.common import (
_contains_datetime_like_objects,
is_np_datetime_like,
Expand Down Expand Up @@ -50,7 +51,7 @@ def _access_through_cftimeindex(values, name):
from xarray.coding.cftimeindex import CFTimeIndex

if not isinstance(values, CFTimeIndex):
values_as_cftimeindex = CFTimeIndex(values.ravel())
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
else:
values_as_cftimeindex = values
if name == "season":
Expand All @@ -69,7 +70,7 @@ def _access_through_series(values, name):
"""Coerce an array of datetime-like values to a pandas Series and
access requested datetime component
"""
values_as_series = pd.Series(values.ravel(), copy=False)
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
if name == "season":
months = values_as_series.dt.month.values
field_values = _season_from_months(months)
Expand Down Expand Up @@ -148,10 +149,10 @@ def _round_through_series_or_index(values, name, freq):
from xarray.coding.cftimeindex import CFTimeIndex

if is_np_datetime_like(values.dtype):
values_as_series = pd.Series(values.ravel(), copy=False)
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
method = getattr(values_as_series.dt, name)
else:
values_as_cftimeindex = CFTimeIndex(values.ravel())
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))
method = getattr(values_as_cftimeindex, name)

field_values = method(freq=freq).values
Expand Down Expand Up @@ -195,7 +196,7 @@ def _strftime_through_cftimeindex(values, date_format: str):
"""
from xarray.coding.cftimeindex import CFTimeIndex

values_as_cftimeindex = CFTimeIndex(values.ravel())
values_as_cftimeindex = CFTimeIndex(duck_array_ops.ravel(values))

field_values = values_as_cftimeindex.strftime(date_format)
return field_values.values.reshape(values.shape)
Expand All @@ -205,7 +206,7 @@ def _strftime_through_series(values, date_format: str):
"""Coerce an array of datetime-like values to a pandas Series and
apply string formatting
"""
values_as_series = pd.Series(values.ravel(), copy=False)
values_as_series = pd.Series(duck_array_ops.ravel(values), copy=False)
strs = values_as_series.dt.strftime(date_format)
return strs.values.reshape(values.shape)

Expand Down
3 changes: 2 additions & 1 deletion xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2123,7 +2123,8 @@ def _calc_idxminmax(
chunkmanager = get_chunked_array_type(array.data)
chunks = dict(zip(array.dims, array.chunks))
dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim])
res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape))
data = dask_coord[duck_array_ops.ravel(indx.data)]
res = indx.copy(data=duck_array_ops.reshape(data, indx.shape))
# we need to attach back the dim name
res.name = dim
else:
Expand Down
6 changes: 5 additions & 1 deletion xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,10 @@ def reshape(array, shape):
return xp.reshape(array, shape)


def ravel(array):
    """Return ``array`` flattened to a single dimension.

    Implemented by calling this module's ``reshape`` with the target shape
    ``(-1,)`` rather than calling a ``.ravel()`` method on the array object.
    """
    flat_shape = (-1,)
    return reshape(array, flat_shape)


@contextlib.contextmanager
def _ignore_warnings_if(condition):
if condition:
Expand All @@ -363,7 +367,7 @@ def f(values, axis=None, skipna=None, **kwargs):
values = asarray(values)

if coerce_strings and values.dtype.kind in "SU":
values = values.astype(object)
values = astype(values, object)

func = None
if skipna or (skipna is None and values.dtype.kind in "cfO"):
Expand Down
10 changes: 7 additions & 3 deletions xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from xarray.core import dtypes, nputils, utils
from xarray.core import dtypes, duck_array_ops, nputils, utils
from xarray.core.duck_array_ops import (
astype,
count,
Expand All @@ -21,12 +21,16 @@ def _maybe_null_out(result, axis, mask, min_count=1):
xarray version of pandas.core.nanops._maybe_null_out
"""
if axis is not None and getattr(result, "ndim", False):
null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0
null_mask = (
np.take(mask.shape, axis).prod()
- duck_array_ops.sum(mask, axis)
- min_count
) < 0
dtype, fill_value = dtypes.maybe_promote(result.dtype)
result = where(null_mask, fill_value, astype(result, dtype))

elif getattr(result, "dtype", None) not in dtypes.NAT_TYPES:
null_mask = mask.size - mask.sum()
null_mask = mask.size - duck_array_ops.sum(mask)
result = where(null_mask < min_count, np.nan, result)

return result
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2324,7 +2324,7 @@ def coarsen_reshape(self, windows, boundary, side):
else:
shape.append(variable.shape[i])

return variable.data.reshape(shape), tuple(axes)
return duck_array_ops.reshape(variable.data, shape), tuple(axes)

def isnull(self, keep_attrs: bool | None = None):
"""Test each value in the array for whether it is a missing value.
Expand Down
12 changes: 8 additions & 4 deletions xarray/tests/test_coarsen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import xarray as xr
from xarray import DataArray, Dataset, set_options
from xarray.core import duck_array_ops
from xarray.tests import (
assert_allclose,
assert_equal,
Expand Down Expand Up @@ -272,21 +273,24 @@ def test_coarsen_construct(self, dask: bool) -> None:
expected = xr.Dataset(attrs={"foo": "bar"})
expected["vart"] = (
("year", "month"),
ds.vart.data.reshape((-1, 12)),
duck_array_ops.reshape(ds.vart.data, (-1, 12)),
{"a": "b"},
)
expected["varx"] = (
("x", "x_reshaped"),
ds.varx.data.reshape((-1, 5)),
duck_array_ops.reshape(ds.varx.data, (-1, 5)),
{"a": "b"},
)
expected["vartx"] = (
("x", "x_reshaped", "year", "month"),
ds.vartx.data.reshape(2, 5, 4, 12),
duck_array_ops.reshape(ds.vartx.data, (2, 5, 4, 12)),
{"a": "b"},
)
expected["vary"] = ds.vary
expected.coords["time"] = (("year", "month"), ds.time.data.reshape((-1, 12)))
expected.coords["time"] = (
("year", "month"),
duck_array_ops.reshape(ds.time.data, (-1, 12)),
)

with raise_if_dask_computes():
actual = ds.coarsen(time=12, x=5).construct(
Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ def test_pad_constant_values(self, xr_arg, np_arg):

actual = v.pad(**xr_arg)
expected = np.pad(
np.array(v.data.astype(float)),
np.array(duck_array_ops.astype(v.data, float)),
np_arg,
mode="constant",
constant_values=np.nan,
Expand Down

0 comments on commit bd40c20

Please sign in to comment.