diff --git a/cubed_xarray/cubedmanager.py b/cubed_xarray/cubedmanager.py index 9ea5212..9ad0735 100644 --- a/cubed_xarray/cubedmanager.py +++ b/cubed_xarray/cubedmanager.py @@ -9,6 +9,8 @@ import numpy as np +from tlz import partition + from xarray.core import utils from xarray.core.parallelcompat import ChunkManagerEntrypoint from xarray.core.pycompat import is_chunked_array, is_duck_dask_array @@ -192,9 +194,17 @@ def unify_chunks( *args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types **kwargs, ) -> tuple[dict[str, T_NormalizedChunks], list["CubedArray"]]: + from cubed.array_api import asarray from cubed.core import unify_chunks - return unify_chunks(*args, **kwargs) + # Ensure that args are Cubed arrays. Note that we do this here and not in Cubed, following + # https://numpy.org/neps/nep-0047-array-api-standard.html#the-asarray-asanyarray-pattern + arginds = [ + (asarray(a) if ind is not None else a, ind) for a, ind in partition(2, args) + ] + array_args = [item for pair in arginds for item in pair] + + return unify_chunks(*array_args, **kwargs) def store( self,