Skip to content

Commit

Permalink
Use general_blockwise in case of chunk-aligned selections in index
Browse files Browse the repository at this point in the history
This allows the optimizer to perform fusion, which is not the case with
the existing `map_direct` implementation.
  • Loading branch information
tomwhite committed Sep 30, 2024
1 parent 87db8ba commit ed3ddad
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 15 deletions.
149 changes: 135 additions & 14 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import builtins
import math
import numbers
from dataclasses import dataclass
from functools import partial
from itertools import product
from numbers import Integral, Number
Expand All @@ -25,9 +26,11 @@
from cubed.primitive.rechunk import rechunk as primitive_rechunk
from cubed.spec import spec_from_config
from cubed.storage.backend import open_backend_array
from cubed.types import T_RegularChunks, T_Shape
from cubed.utils import (
_concatenate2,
array_memory,
array_size,
get_item,
offset_to_block_id,
to_chunksize,
Expand Down Expand Up @@ -484,6 +487,11 @@ def merged_chunk_len_for_indexer(ia, c):
if shape == x.shape:
# no op case (except possibly newaxis applied below)
out = x
elif array_size(shape) == 0:
# empty output case
from cubed.array_api.creation_functions import empty

out = empty(shape, dtype=x.dtype, spec=x.spec)
else:
dtype = x.dtype
chunks = tuple(
Expand All @@ -494,21 +502,68 @@ def merged_chunk_len_for_indexer(ia, c):

target_chunks = normalize_chunks(chunks, shape, dtype=dtype)

# memory allocated by reading one chunk from input array
# note that although the output chunk will overlap multiple input chunks, zarr will
# read the chunks in series, reusing the buffer
extra_projected_mem = x.chunkmem
if _is_chunk_aligned_selection(idx):
# use general_blockwise, which allows more opportunities for optimization than map_direct

out = map_direct(
_read_index_chunk,
x,
shape=shape,
dtype=dtype,
chunks=target_chunks,
extra_projected_mem=extra_projected_mem,
target_chunks=target_chunks,
selection=selection,
)
from cubed.array_api.creation_functions import offsets_virtual_array

# general_blockwise doesn't support block_id, so emulate it ourselves
numblocks = tuple(map(len, target_chunks))
offsets = offsets_virtual_array(numblocks, x.spec)

def key_function(out_key):
out_coords = out_key[1:]

# compute the selection on x required to get the relevant chunk for out_coords
in_sel = _target_chunk_selection(target_chunks, out_coords, selection)

# use a Zarr BasicIndexer to convert this to input coordinates
indexer = create_basic_indexer(
in_sel, x.zarray_maybe_lazy.shape, x.zarray_maybe_lazy.chunks
)

offset_in_key = ((offsets.name,) + out_coords,)
return (
tuple((x.name,) + chunk_coords for (chunk_coords, _, _) in indexer)
+ offset_in_key
)

# since selection is chunk-aligned, we know that we only read one block of x
num_input_blocks = (1, 1) # x, offsets

out = general_blockwise(
_assemble_index_chunk,
key_function,
x,
offsets,
shapes=[shape],
dtypes=[x.dtype],
chunkss=[target_chunks],
num_input_blocks=num_input_blocks,
target_chunks=target_chunks,
selection=selection,
in_shape=x.shape,
in_chunksize=x.chunksize,
)
else:
# use map_direct, which can't be fused
# (note that it should be possible to re-write as general_blockwise with more work)

# memory allocated by reading one chunk from input array
# note that although the output chunk will overlap multiple input chunks, zarr will
# read the chunks in series, reusing the buffer
extra_projected_mem = x.chunkmem

out = map_direct(
_read_index_chunk,
x,
shape=shape,
dtype=dtype,
chunks=target_chunks,
extra_projected_mem=extra_projected_mem,
target_chunks=target_chunks,
selection=selection,
)

# merge chunks for any dims with step > 1 so they are
# the same size as the input (or slightly smaller due to rounding)
Expand All @@ -528,6 +583,72 @@ def merged_chunk_len_for_indexer(ia, c):
return out


def _is_chunk_aligned_selection(idx: ndindex.Tuple):
    """Return True if every indexer in ``idx`` keeps the input chunk alignment.

    An indexer qualifies when it is an ``ndindex.Integer``, or an
    ``ndindex.Slice`` whose start is 0 and whose step is either ``None`` or 1.
    Any other indexer makes the whole selection non-aligned.

    The indexing code uses this to decide between ``general_blockwise`` (for
    aligned selections) and ``map_direct`` (for everything else).
    """

    def _keeps_alignment(indexer):
        # integer indexers always qualify
        if isinstance(indexer, ndindex.Integer):
            return True
        # anything that is neither an integer nor a slice does not qualify
        if not isinstance(indexer, ndindex.Slice):
            return False
        starts_at_zero = indexer.start == 0
        has_unit_step = indexer.step is None or indexer.step == 1
        return starts_at_zero and has_unit_step

    for indexer in idx.args:
        if not _keeps_alignment(indexer):
            return False
    return True


def create_basic_indexer(selection, shape, chunks):
    """Build a Zarr ``BasicIndexer`` for ``selection`` over an array layout.

    The import location and constructor arguments of ``BasicIndexer`` depend
    on the installed Zarr version, so this helper hides the difference.

    Parameters
    ----------
    selection
        The selection to hand to ``BasicIndexer``.
    shape
        Shape of the array being indexed.
    chunks
        Regular chunk shape of the array being indexed.

    Returns
    -------
    A ``BasicIndexer`` from ``zarr.core.indexing`` when the first character of
    ``zarr.__version__`` is ``"3"``, otherwise one from ``zarr.indexing``.
    """
    uses_zarr_v3 = zarr.__version__[0] == "3"

    if not uses_zarr_v3:
        from zarr.indexing import BasicIndexer

        # this BasicIndexer takes an array-like object rather than a separate
        # shape and chunk grid, so wrap the layout in a lightweight adaptor
        array_like = ZarrArrayIndexingAdaptor(shape, chunks)
        return BasicIndexer(selection, array_like)

    from zarr.core.chunk_grids import RegularChunkGrid
    from zarr.core.indexing import BasicIndexer

    chunk_grid = RegularChunkGrid(chunk_shape=chunks)
    return BasicIndexer(selection, shape, chunk_grid)


@dataclass
class ZarrArrayIndexingAdaptor:
    """Lightweight holder for an array's shape and chunk shape.

    ``create_basic_indexer`` passes an instance of this class to
    ``zarr.indexing.BasicIndexer`` in place of a real Zarr array.
    NOTE(review): presumably that indexer only reads ``_shape`` and
    ``_chunks`` from its array argument - confirm against the Zarr version
    in use.
    """

    # shape of the array being indexed
    _shape: T_Shape
    # regular chunk shape of the array being indexed
    _chunks: T_RegularChunks

    @classmethod
    def from_zarr_array(cls, zarray):
        """Create an adaptor from ``zarray.shape`` and ``zarray.chunks``."""
        return cls(_shape=zarray.shape, _chunks=zarray.chunks)


def _assemble_index_chunk(
    *arrs,
    target_chunks=None,
    selection=None,
    in_shape=None,
    in_chunksize=None,
):
    """Build one output chunk of an index operation from already-read input blocks.

    Parameters
    ----------
    *arrs
        The input blocks, followed by a final 0-d array holding the linear
        offset of the output block being computed.
    target_chunks
        Chunks of the output array, one sequence of chunk lengths per dimension.
    selection
        The selection being applied to the input array.
    in_shape
        Shape of the input array.
    in_chunksize
        Regular chunk shape of the input array.

    Returns
    -------
    An array with the indexer's output shape, created with
    ``np.empty_like`` from the first input block and filled from the input
    blocks.
    """
    # the last positional argument carries the offset; convert it from a 0-d array
    offset = int(arrs[-1])
    blocks = arrs[:-1]

    # recover the block id (chunk coordinates) of the output chunk from the offset
    grid_shape = tuple(map(len, target_chunks))
    block_id = offset_to_block_id(offset, grid_shape)

    # selection on the input array that yields the output chunk at block_id
    in_sel = _target_chunk_selection(target_chunks, block_id, selection)

    # a Zarr BasicIndexer translates that selection into per-chunk selections
    indexer = create_basic_indexer(in_sel, in_shape, in_chunksize)

    out_shape = indexer.shape
    out = np.empty_like(blocks[0], shape=out_shape)

    # nothing to copy for an empty output chunk
    if array_size(out_shape) <= 0:
        return out

    # each indexer entry is (chunk coords, selection within chunk, selection within output);
    # the entries are paired positionally with the input blocks
    _, chunk_selections, out_selections = zip(*indexer)
    for block, chunk_sel, out_sel in zip(blocks, chunk_selections, out_selections):
        out[out_sel] = block[chunk_sel]

    return out


def _read_index_chunk(
x,
*arrays,
Expand Down
9 changes: 9 additions & 0 deletions cubed/tests/test_mem_utilization.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ def test_index(tmp_path, spec, executor):
run_operation(tmp_path, executor, "index", b)


@pytest.mark.slow
def test_index_chunk_aligned(tmp_path, spec, executor):
    """Index a random 2-d array with slices that start at 0 and pass the result to ``run_operation``."""
    shape = (10000, 10000)
    chunks = (5000, 5000)  # 200MB chunks
    source = cubed.random.random(shape, chunks=chunks, spec=spec)
    # the first axis keeps only the first chunk-length of rows; the second axis is taken whole
    selected = source[0:5000, :]
    run_operation(tmp_path, executor, "index_chunk_aligned", selected)


@pytest.mark.slow
def test_index_step(tmp_path, spec, executor):
a = cubed.random.random(
Expand Down
8 changes: 7 additions & 1 deletion cubed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
from functools import partial
from itertools import islice
from math import prod
from operator import add
from operator import add, mul
from pathlib import Path
from posixpath import join
from typing import Dict, Tuple, Union, cast
from urllib.parse import quote, unquote, urlsplit, urlunsplit

import numpy as np
import tlz as toolz
from toolz import reduce

from cubed.backend_array_api import namespace as nxp
from cubed.types import T_DType, T_RectangularChunks, T_RegularChunks, T_Shape
Expand All @@ -41,6 +42,11 @@ def chunk_memory(arr) -> int:
)


def array_size(shape: T_Shape) -> int:
    """Number of elements in an array.

    This is the product of the dimension lengths in ``shape``; an empty
    shape (a 0-d array) gives 1, and any zero-length dimension gives 0.
    """
    return prod(shape)


def offset_to_block_id(offset: int, numblocks: Tuple[int, ...]) -> Tuple[int, ...]:
    """Convert an index offset to a block ID (chunk coordinates).

    ``offset`` is a flat index into a grid with ``numblocks`` blocks along
    each dimension; it is unravelled with ``np.unravel_index`` using NumPy's
    default ordering. The coordinates are returned as plain Python ints
    rather than NumPy integer scalars.
    """
    coords = np.unravel_index(offset, numblocks)
    return tuple(map(int, coords))
Expand Down

0 comments on commit ed3ddad

Please sign in to comment.