-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add scan. #531
base: main
Are you sure you want to change the base?
Add scan. #531
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -5,7 +5,7 @@ | |||||||||||||||
from itertools import product | ||||||||||||||||
from numbers import Integral, Number | ||||||||||||||||
from operator import add | ||||||||||||||||
from typing import TYPE_CHECKING, Any, Sequence, Union | ||||||||||||||||
from typing import TYPE_CHECKING, Any, Callable, Sequence, Union | ||||||||||||||||
from warnings import warn | ||||||||||||||||
|
||||||||||||||||
import ndindex | ||||||||||||||||
|
@@ -22,6 +22,7 @@ | |||||||||||||||
from cubed.core.plan import Plan, new_temp_path | ||||||||||||||||
from cubed.primitive.blockwise import blockwise as primitive_blockwise | ||||||||||||||||
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise | ||||||||||||||||
from cubed.primitive.blockwise import key_to_slices | ||||||||||||||||
from cubed.primitive.rechunk import rechunk as primitive_rechunk | ||||||||||||||||
from cubed.spec import spec_from_config | ||||||||||||||||
from cubed.storage.backend import open_backend_array | ||||||||||||||||
|
@@ -1442,3 +1443,120 @@ def smallest_blockdim(blockdims): | |||||||||||||||
m = ntd[0] | ||||||||||||||||
out = ntd | ||||||||||||||||
return out | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
def wrapper_binop( | ||||||||||||||||
out: np.ndarray, | ||||||||||||||||
left: Array, | ||||||||||||||||
right: Array, | ||||||||||||||||
*, | ||||||||||||||||
binop: Callable, | ||||||||||||||||
block_id: tuple[int, ...], | ||||||||||||||||
axis: int, | ||||||||||||||||
identity: Any, | ||||||||||||||||
) -> Array: | ||||||||||||||||
# print(type(out), out.shape) | ||||||||||||||||
# print(block_id) | ||||||||||||||||
# print("left", left) | ||||||||||||||||
# print("right", right) | ||||||||||||||||
left_slicer = key_to_slices(block_id, left) | ||||||||||||||||
right_slicer = list(left_slicer) | ||||||||||||||||
|
||||||||||||||||
# For the first block, we add the identity element | ||||||||||||||||
# For all other blocks `k`, we add the `k-1` element along `axis` | ||||||||||||||||
right_slicer[axis] = slice(block_id[axis] - 1, block_id[axis]) | ||||||||||||||||
right_slicer = tuple(right_slicer) | ||||||||||||||||
right_ = right[right_slicer] if block_id[axis] > 0 else identity | ||||||||||||||||
# print("left", left[left_slicer].shape) | ||||||||||||||||
# print("right", right_.shape) | ||||||||||||||||
return binop(left[left_slicer], right_) | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
def scan( | ||||||||||||||||
array: "Array", | ||||||||||||||||
func: Callable, | ||||||||||||||||
*, | ||||||||||||||||
preop: Callable, | ||||||||||||||||
binop: Callable, | ||||||||||||||||
identity: Any, | ||||||||||||||||
axis: int, | ||||||||||||||||
dtype=None, | ||||||||||||||||
) -> Array: | ||||||||||||||||
""" | ||||||||||||||||
Generic parallel scan. | ||||||||||||||||
|
||||||||||||||||
Parameters | ||||||||||||||||
---------- | ||||||||||||||||
x: Cubed Array | ||||||||||||||||
func: callable | ||||||||||||||||
Scan or cumulative function like np.cumsum or np.cumprod | ||||||||||||||||
preop: callable | ||||||||||||||||
Function applied blockwise that reduces each block to a single value | ||||||||||||||||
along ``axis``. For ``np.cumsum`` this is ``np.sum`` and for ``np.cumprod`` this is ``np.prod``. | ||||||||||||||||
binop: callable | ||||||||||||||||
Associated binary operator like ``np.cumsum->add`` or ``np.cumprod->mul`` | ||||||||||||||||
identity: Any | ||||||||||||||||
Associated identity element more scan like 0 for ``np.cumsum`` and 1 for ``np.cumprod``. | ||||||||||||||||
axis: int | ||||||||||||||||
dtype: dtype | ||||||||||||||||
|
||||||||||||||||
Notes | ||||||||||||||||
----- | ||||||||||||||||
This method uses a variant of the Blelloch (1989) alogrithm. | ||||||||||||||||
|
||||||||||||||||
Returns | ||||||||||||||||
------- | ||||||||||||||||
Array | ||||||||||||||||
|
||||||||||||||||
See also | ||||||||||||||||
-------- | ||||||||||||||||
cumsum | ||||||||||||||||
cumprod | ||||||||||||||||
""" | ||||||||||||||||
# Blelloch (1990) out-of-core algorithm. | ||||||||||||||||
# 1. First, scan blockwise | ||||||||||||||||
scanned = blockwise(func, "ij", array, "ij", axis=axis) | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using |
||||||||||||||||
# If there is only a single chunk, we can be done | ||||||||||||||||
if array.numblocks[-1] == 1: | ||||||||||||||||
return scanned | ||||||||||||||||
|
||||||||||||||||
# 2. Calculate the blockwise reduction using `preop` | ||||||||||||||||
# TODO: could also merge(1,2) by returning {"scan": np.cumsum(array), "preop": np.sum(array)} in `scanned` | ||||||||||||||||
reduced = blockwise( | ||||||||||||||||
preop, "ij", array, "ij", axis=axis, adjust_chunks={"j": 1}, keepdims=True | ||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
# 3. Now scan `reduced` to generate the increments for each block of `scanned`. | ||||||||||||||||
# Here we diverge from Blelloch, who runs a balanced tree algorithm to calculate the scan. | ||||||||||||||||
# Instead we generalize recursively apply the scan to `reduced`. | ||||||||||||||||
# 3a. First we merge to a decent intermediate chunksize since reduced.chunksize[axis] == 1 | ||||||||||||||||
new_chunksize = min(reduced.shape[axis], reduced.chunksize[axis] * 5) | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need input here on choosing a new intermediate chunksize to rechunk to based on memory info. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are a couple of things to consider here: the number of chunks to combine at each stage, and the memory limits. The first is like For the second, we should make sure the new chunksize is no larger than There is an error case where this memory constraint means the new chunksize is no larger than the existing one, so the computation can't proceed. The user can fix this either by reducing the chunksize or by increasing the memory. This is similar to this case: Lines 985 to 991 in 88c5dc4
|
||||||||||||||||
new_chunks = reduced.chunksize[:-1] + (new_chunksize,) | ||||||||||||||||
merged = merge_chunks(reduced, new_chunks) | ||||||||||||||||
|
||||||||||||||||
# 3b. Recursively scan this merged array to generate the increment for each block of `scanned` | ||||||||||||||||
increment = scan( | ||||||||||||||||
merged, func, preop=preop, binop=binop, identity=identity, axis=axis | ||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
# 4. Back to Blelloch. Now that we have the increment, add it to the blocks of `scanned`. | ||||||||||||||||
# Use map_direct since the chunks of increment and scanned aren't aligned anymore. | ||||||||||||||||
assert increment.shape[axis] == scanned.numblocks[axis] | ||||||||||||||||
# 5. Bada-bing, bada-boom. | ||||||||||||||||
return map_direct( | ||||||||||||||||
partial(wrapper_binop, binop=binop, axis=axis, identity=identity), | ||||||||||||||||
scanned, | ||||||||||||||||
increment, | ||||||||||||||||
shape=scanned.shape, | ||||||||||||||||
dtype=scanned.dtype, | ||||||||||||||||
chunks=scanned.chunks, | ||||||||||||||||
extra_projected_mem=scanned.chunkmem * 2, # arbitrary | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. need input here too. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be the memory allocated to read from the side inputs ( extra_projected_mem=scanned.chunkmem * 2 + increment.chunkmem * 2 (There's an open issue #288 to make this a bit more transparent.) |
||||||||||||||||
) | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
# result = scan( | ||||||||||||||||
# array, preop=np.sum, func=np.cumsum, binop=np.add, identity=0, axis=-1 | ||||||||||||||||
# ) | ||||||||||||||||
# print(result) | ||||||||||||||||
# print(result.compute()) | ||||||||||||||||
# np.testing.assert_equal(result, np.cumsum(array.compute(), axis=-1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe call something like
_scan_binop
to link it to the scan implementation? I've been using a naming convention like that elsewhere in the file.