Skip to content

Commit

Permalink
Ensure args are Cubed arrays in unify_chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Sep 19, 2023
1 parent d1c773b commit 4f28e57
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion cubed_xarray/cubedmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy as np

from tlz import partition

from xarray.core import utils
from xarray.core.parallelcompat import ChunkManagerEntrypoint
from xarray.core.pycompat import is_chunked_array, is_duck_dask_array
Expand Down Expand Up @@ -192,9 +194,17 @@ def unify_chunks(
*args: Any, # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types
**kwargs,
) -> tuple[dict[str, T_NormalizedChunks], list["CubedArray"]]:
from cubed.array_api import asarray
from cubed.core import unify_chunks

return unify_chunks(*args, **kwargs)
# Ensure that args are Cubed arrays. Note that we do this here and not in Cubed, following
# https://numpy.org/neps/nep-0047-array-api-standard.html#the-asarray-asanyarray-pattern
arginds = [
(asarray(a) if ind is not None else a, ind) for a, ind in partition(2, args)
]
array_args = [item for pair in arginds for item in pair]

return unify_chunks(*array_args, **kwargs)

def store(
self,
Expand Down

0 comments on commit 4f28e57

Please sign in to comment.