From 4f28e574bf9bbba2fddeff8ca1978967f652f69f Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 19 Sep 2023 10:44:42 +0100 Subject: [PATCH] Ensure args are Cubed arrays in `unify_chunks` --- cubed_xarray/cubedmanager.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/cubed_xarray/cubedmanager.py b/cubed_xarray/cubedmanager.py index 9ea5212..9ad0735 100644 --- a/cubed_xarray/cubedmanager.py +++ b/cubed_xarray/cubedmanager.py @@ -9,6 +9,8 @@ import numpy as np +from tlz import partition + from xarray.core import utils from xarray.core.parallelcompat import ChunkManagerEntrypoint from xarray.core.pycompat import is_chunked_array, is_duck_dask_array @@ -192,9 +194,17 @@ def unify_chunks( *args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types **kwargs, ) -> tuple[dict[str, T_NormalizedChunks], list["CubedArray"]]: + from cubed.array_api import asarray from cubed.core import unify_chunks - return unify_chunks(*args, **kwargs) + # Ensure that args are Cubed arrays. Note that we do this here and not in Cubed, following + # https://numpy.org/neps/nep-0047-array-api-standard.html#the-asarray-asanyarray-pattern + arginds = [ + (asarray(a) if ind is not None else a, ind) for a, ind in partition(2, args) + ] + array_args = [item for pair in arginds for item in pair] + + return unify_chunks(*array_args, **kwargs) def store( self,