Skip to content
forked from pydata/xarray

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into kvikio-2
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Sep 20, 2023
2 parents e43e720 + 99f8446 commit fe87c21
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 50 deletions.
6 changes: 3 additions & 3 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import JoinOptions, T_DataArray, T_Dataset
from xarray.core.types import JoinOptions, T_DataArray, T_Dataset, T_DuckArray


def reindex_variables(
Expand Down Expand Up @@ -173,7 +173,7 @@ def __init__(

def _normalize_indexes(
self,
indexes: Mapping[Any, Any],
indexes: Mapping[Any, Any | T_DuckArray],
) -> tuple[NormalizedIndexes, NormalizedIndexVars]:
"""Normalize the indexes/indexers used for re-indexing or alignment.
Expand All @@ -194,7 +194,7 @@ def _normalize_indexes(
f"Indexer has dimensions {idx.dims} that are different "
f"from that to be indexed along '{k}'"
)
data = as_compatible_data(idx)
data: T_DuckArray = as_compatible_data(idx)
pd_idx = safe_cast_to_index(data)
pd_idx.name = k
if isinstance(pd_idx, pd.MultiIndex):
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7481,7 +7481,7 @@ def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset:
else:
variables[k] = f(v, *args, **kwargs)
if keep_attrs:
variables[k].attrs = v._attrs
variables[k]._attrs = v._attrs
attrs = self._attrs if keep_attrs else None
return self._replace_with_new_dims(variables, attrs=attrs)

Expand Down
4 changes: 2 additions & 2 deletions xarray/core/parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
T_ChunkedArray = TypeVar("T_ChunkedArray")

if TYPE_CHECKING:
from xarray.core.types import T_Chunks, T_NormalizedChunks
from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks


@functools.lru_cache(maxsize=1)
Expand Down Expand Up @@ -257,7 +257,7 @@ def normalize_chunks(

@abstractmethod
def from_array(
self, data: np.ndarray, chunks: T_Chunks, **kwargs
self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs
) -> T_ChunkedArray:
"""
Create a chunked array from a non-chunked numpy-like array.
Expand Down
4 changes: 4 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def copy(
T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
T_Alignable = TypeVar("T_Alignable", bound="Alignable")

# Temporary placeholder for indicating an array api compliant type.
# hopefully in the future we can narrow this down more:
T_DuckArray = TypeVar("T_DuckArray", bound=Any)

ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"]
DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"]
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
import pandas as pd

if TYPE_CHECKING:
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims, T_DuckArray

K = TypeVar("K")
V = TypeVar("V")
Expand Down Expand Up @@ -253,7 +253,7 @@ def is_list_like(value: Any) -> TypeGuard[list | tuple]:
return isinstance(value, (list, tuple))


def is_duck_array(value: Any) -> bool:
def is_duck_array(value: Any) -> TypeGuard[T_DuckArray]:
if isinstance(value, np.ndarray):
return True
return (
Expand Down
76 changes: 45 additions & 31 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Hashable, Iterable, Mapping, Sequence
from datetime import timedelta
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn
from typing import TYPE_CHECKING, Any, Callable, Literal, NoReturn, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -66,6 +66,7 @@
PadModeOptions,
PadReflectOptions,
QuantileMethods,
T_DuckArray,
T_Variable,
)

Expand All @@ -86,7 +87,7 @@ class MissingDimensionsError(ValueError):
# TODO: move this to an xarray.exceptions module?


def as_variable(obj, name=None) -> Variable | IndexVariable:
def as_variable(obj: T_DuckArray | Any, name=None) -> Variable | IndexVariable:
"""Convert an object into a Variable.
Parameters
Expand Down Expand Up @@ -142,7 +143,7 @@ def as_variable(obj, name=None) -> Variable | IndexVariable:
elif isinstance(obj, (set, dict)):
raise TypeError(f"variable {name!r} has invalid type {type(obj)!r}")
elif name is not None:
data = as_compatible_data(obj)
data: T_DuckArray = as_compatible_data(obj)
if data.ndim != 1:
raise MissingDimensionsError(
f"cannot set variable {name!r} with {data.ndim!r}-dimensional data "
Expand Down Expand Up @@ -230,7 +231,9 @@ def _possibly_convert_datetime_or_timedelta_index(data):
return data


def as_compatible_data(data, fastpath: bool = False):
def as_compatible_data(
data: T_DuckArray | ArrayLike, fastpath: bool = False
) -> T_DuckArray:
"""Prepare and wrap data to put in a Variable.
- If data does not have the necessary attributes, convert it to ndarray.
Expand All @@ -243,7 +246,7 @@ def as_compatible_data(data, fastpath: bool = False):
"""
if fastpath and getattr(data, "ndim", 0) > 0:
# can't use fastpath (yet) for scalars
return _maybe_wrap_data(data)
return cast("T_DuckArray", _maybe_wrap_data(data))

from xarray.core.dataarray import DataArray

Expand All @@ -252,7 +255,7 @@ def as_compatible_data(data, fastpath: bool = False):

if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES):
data = _possibly_convert_datetime_or_timedelta_index(data)
return _maybe_wrap_data(data)
return cast("T_DuckArray", _maybe_wrap_data(data))

if isinstance(data, tuple):
data = utils.to_0d_object_array(data)
Expand All @@ -279,7 +282,7 @@ def as_compatible_data(data, fastpath: bool = False):
if not isinstance(data, np.ndarray) and (
hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")
):
return data
return cast("T_DuckArray", data)

# validate whether the data is valid data types.
data = np.asarray(data)
Expand Down Expand Up @@ -335,7 +338,14 @@ class Variable(AbstractArray, NdimSizeLenMixin, VariableArithmetic):

__slots__ = ("_dims", "_data", "_attrs", "_encoding")

def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
def __init__(
self,
dims,
data: T_DuckArray | ArrayLike,
attrs=None,
encoding=None,
fastpath=False,
):
"""
Parameters
----------
Expand All @@ -355,9 +365,9 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
Well-behaved code to serialize a Variable should ignore
unrecognized encoding items.
"""
self._data = as_compatible_data(data, fastpath=fastpath)
self._data: T_DuckArray = as_compatible_data(data, fastpath=fastpath)
self._dims = self._parse_dimensions(dims)
self._attrs = None
self._attrs: dict[Any, Any] | None = None
self._encoding = None
if attrs is not None:
self.attrs = attrs
Expand Down Expand Up @@ -410,7 +420,7 @@ def _in_memory(self):
)

@property
def data(self) -> Any:
def data(self: T_Variable):
"""
The Variable's data as an array. The underlying array type
(e.g. dask, sparse, pint) is preserved.
Expand All @@ -429,12 +439,12 @@ def data(self) -> Any:
return self.values

@data.setter
def data(self, data):
def data(self: T_Variable, data: T_DuckArray | ArrayLike) -> None:
data = as_compatible_data(data)
if data.shape != self.shape:
if data.shape != self.shape: # type: ignore[attr-defined]
raise ValueError(
f"replacement data must match the Variable's shape. "
f"replacement data has shape {data.shape}; Variable has shape {self.shape}"
f"replacement data has shape {data.shape}; Variable has shape {self.shape}" # type: ignore[attr-defined]
)
self._data = data

Expand Down Expand Up @@ -996,7 +1006,7 @@ def reset_encoding(self: T_Variable) -> T_Variable:
return self._replace(encoding={})

def copy(
self: T_Variable, deep: bool = True, data: ArrayLike | None = None
self: T_Variable, deep: bool = True, data: T_DuckArray | ArrayLike | None = None
) -> T_Variable:
"""Returns a copy of this object.
Expand Down Expand Up @@ -1058,24 +1068,26 @@ def copy(
def _copy(
self: T_Variable,
deep: bool = True,
data: ArrayLike | None = None,
data: T_DuckArray | ArrayLike | None = None,
memo: dict[int, Any] | None = None,
) -> T_Variable:
if data is None:
ndata = self._data
data_old = self._data

if isinstance(ndata, indexing.MemoryCachedArray):
if isinstance(data_old, indexing.MemoryCachedArray):
# don't share caching between copies
ndata = indexing.MemoryCachedArray(ndata.array)
ndata = indexing.MemoryCachedArray(data_old.array)
else:
ndata = data_old

if deep:
ndata = copy.deepcopy(ndata, memo)

else:
ndata = as_compatible_data(data)
if self.shape != ndata.shape:
if self.shape != ndata.shape: # type: ignore[attr-defined]
raise ValueError(
f"Data shape {ndata.shape} must match shape of object {self.shape}"
f"Data shape {ndata.shape} must match shape of object {self.shape}" # type: ignore[attr-defined]
)

attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs)
Expand Down Expand Up @@ -1248,11 +1260,11 @@ def chunk(
inline_array=inline_array,
)

data = self._data
if chunkmanager.is_chunked_array(data):
data = chunkmanager.rechunk(data, chunks) # type: ignore[arg-type]
data_old = self._data
if chunkmanager.is_chunked_array(data_old):
data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type]
else:
if isinstance(data, indexing.ExplicitlyIndexed):
if isinstance(data_old, indexing.ExplicitlyIndexed):
# Unambiguously handle array storage backends (like NetCDF4 and h5py)
# that can't handle general array indexing. For example, in netCDF4 you
# can do "outer" indexing along two dimensions independent, which works
Expand All @@ -1261,20 +1273,22 @@ def chunk(
# Using OuterIndexer is a pragmatic choice: dask does not yet handle
# different indexing types in an explicit way:
# https://github.com/dask/dask/issues/2883
data = indexing.ImplicitToExplicitIndexingAdapter(
data, indexing.OuterIndexer
ndata = indexing.ImplicitToExplicitIndexingAdapter(
data_old, indexing.OuterIndexer
)
else:
ndata = data_old

if utils.is_dict_like(chunks):
chunks = tuple(chunks.get(n, s) for n, s in enumerate(data.shape))
chunks = tuple(chunks.get(n, s) for n, s in enumerate(ndata.shape))

data = chunkmanager.from_array(
data,
data_chunked = chunkmanager.from_array(
ndata,
chunks, # type: ignore[arg-type]
**_from_array_kwargs,
)

return self._replace(data=data)
return self._replace(data=data_chunked)

def to_numpy(self) -> np.ndarray:
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""
Expand Down
55 changes: 51 additions & 4 deletions xarray/tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def test_dask_distributed_write_netcdf_with_dimensionless_variables(

@requires_cftime
@requires_netCDF4
def test_open_mfdataset_can_open_files_with_cftime_index(tmp_path):
@pytest.mark.parametrize("parallel", (True, False))
def test_open_mfdataset_can_open_files_with_cftime_index(parallel, tmp_path):
T = xr.cftime_range("20010101", "20010501", calendar="360_day")
Lon = np.arange(100)
data = np.random.random((T.size, Lon.size))
Expand All @@ -135,9 +136,55 @@ def test_open_mfdataset_can_open_files_with_cftime_index(tmp_path):
da.to_netcdf(file_path)
with cluster() as (s, [a, b]):
with Client(s["address"]):
for parallel in (False, True):
with xr.open_mfdataset(file_path, parallel=parallel) as tf:
assert_identical(tf["test"], da)
with xr.open_mfdataset(file_path, parallel=parallel) as tf:
assert_identical(tf["test"], da)


@requires_cftime
@requires_netCDF4
@pytest.mark.parametrize("parallel", (True, False))
def test_open_mfdataset_multiple_files_parallel_distributed(parallel, tmp_path):
lon = np.arange(100)
time = xr.cftime_range("20010101", periods=100, calendar="360_day")
data = np.random.random((time.size, lon.size))
da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test")

fnames = []
for i in range(0, 100, 10):
fname = tmp_path / f"test_{i}.nc"
da.isel(time=slice(i, i + 10)).to_netcdf(fname)
fnames.append(fname)

with cluster() as (s, [a, b]):
with Client(s["address"]):
with xr.open_mfdataset(
fnames, parallel=parallel, concat_dim="time", combine="nested"
) as tf:
assert_identical(tf["test"], da)


# TODO: move this to test_backends.py
@requires_cftime
@requires_netCDF4
@pytest.mark.parametrize("parallel", (True, False))
def test_open_mfdataset_multiple_files_parallel(parallel, tmp_path):
lon = np.arange(100)
time = xr.cftime_range("20010101", periods=100, calendar="360_day")
data = np.random.random((time.size, lon.size))
da = xr.DataArray(data, coords={"time": time, "lon": lon}, name="test")

fnames = []
for i in range(0, 100, 10):
fname = tmp_path / f"test_{i}.nc"
da.isel(time=slice(i, i + 10)).to_netcdf(fname)
fnames.append(fname)

for get in [dask.threaded.get, dask.multiprocessing.get, dask.local.get_sync, None]:
with dask.config.set(scheduler=get):
with xr.open_mfdataset(
fnames, parallel=parallel, concat_dim="time", combine="nested"
) as tf:
assert_identical(tf["test"], da)


@pytest.mark.parametrize("engine,nc_format", ENGINES_AND_FORMATS)
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_parallelcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
guess_chunkmanager,
list_chunkmanagers,
)
from xarray.core.types import T_Chunks, T_NormalizedChunks
from xarray.core.types import T_Chunks, T_DuckArray, T_NormalizedChunks
from xarray.tests import has_dask, requires_dask


Expand Down Expand Up @@ -76,7 +76,7 @@ def normalize_chunks(
return normalize_chunks(chunks, shape, limit, dtype, previous_chunks)

def from_array(
self, data: np.ndarray, chunks: T_Chunks, **kwargs
self, data: T_DuckArray | np.typing.ArrayLike, chunks: T_Chunks, **kwargs
) -> DummyChunkedArray:
from dask import array as da

Expand Down
Loading

0 comments on commit fe87c21

Please sign in to comment.