Skip to content

Commit

Permalink
Implement repeat (#610)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Nov 8, 2024
1 parent 25bc395 commit a2b1053
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 2 deletions.
2 changes: 1 addition & 1 deletion api_status.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ This table shows which parts of the [Array API](https://data-apis.org/array-
| | `expand_dims` | :white_check_mark: | | |
| | `flip` | :white_check_mark: | | |
| | `permute_dims` | :white_check_mark: | | |
| | `repeat` | :x: | 2023.12 | |
| | `repeat` | :white_check_mark: | | |
| | `reshape` | :white_check_mark: | | Partial implementation |
| | `roll` | :white_check_mark: | | |
| | `squeeze` | :white_check_mark: | | |
Expand Down
2 changes: 2 additions & 0 deletions cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@
flip,
moveaxis,
permute_dims,
repeat,
reshape,
roll,
squeeze,
Expand All @@ -311,6 +312,7 @@
"flip",
"moveaxis",
"permute_dims",
"repeat",
"reshape",
"roll",
"squeeze",
Expand Down
4 changes: 3 additions & 1 deletion cubed/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@
flip,
moveaxis,
permute_dims,
repeat,
reshape,
roll,
squeeze,
Expand All @@ -253,6 +254,7 @@
"flip",
"moveaxis",
"permute_dims",
"repeat",
"reshape",
"roll",
"squeeze",
Expand All @@ -264,7 +266,7 @@

__all__ += ["argmax", "argmin", "where"]

from .statistical_functions import max, mean, min, prod, sum, std, var
from .statistical_functions import max, mean, min, prod, std, sum, var

__all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]

Expand Down
44 changes: 44 additions & 0 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,50 @@ def permute_dims(x, /, axes):
)


def repeat(x, repeats, /, *, axis=0):
    """Repeat each element of ``x`` ``repeats`` times along ``axis``.

    Parameters
    ----------
    x : array
        Input array.
    repeats : int
        Number of times each element is repeated. Only a single
        non-negative integer is supported (not an array of per-element
        counts).
    axis : int or None, optional
        Axis along which to repeat. Negative values count from the last
        axis. If ``None``, ``x`` is flattened first and the repeat is
        applied along the resulting single axis.

    Returns
    -------
    array
        Array with the same dtype as ``x`` whose extent along ``axis`` is
        ``x.shape[axis] * repeats``.

    Raises
    ------
    ValueError
        If ``repeats`` is not an integer or is negative.
    IndexError
        If ``axis`` is out of range for ``x``.
    """
    import operator

    try:
        repeats = operator.index(repeats)
    except TypeError:
        raise ValueError("repeat only implemented for int repeats") from None
    if repeats < 0:
        raise ValueError("repeats may not be negative")

    if axis is None:
        x = flatten(x)
        axis = 0

    # Normalise negative axes up front: the shape slicing below and the
    # `i == axis` comparisons only work for a non-negative axis.
    ndim = len(x.shape)
    if not -ndim <= axis < ndim:
        raise IndexError(
            f"axis {axis} is out of bounds for array of dimension {ndim}"
        )
    axis %= ndim

    shape = x.shape[:axis] + (x.shape[axis] * repeats,) + x.shape[axis + 1 :]
    chunks = normalize_chunks(x.chunksize, shape=shape, dtype=x.dtype)

    # This implementation calls nxp.repeat once for every output block, which
    # is 'repeats' times more calls than would be needed if we had a primitive
    # op that could write multiple output blocks from a single input block.

    def key_function(out_key):
        # The output keeps the input chunk size, so 'repeats' consecutive
        # output blocks along 'axis' are all derived from one input block.
        out_coords = out_key[1:]
        in_coords = tuple(
            bi // repeats if i == axis else bi for i, bi in enumerate(out_coords)
        )
        return ((x.name, *in_coords),)

    # extra memory from calling 'nxp.repeat' on a chunk
    extra_projected_mem = x.chunkmem * repeats
    return general_blockwise(
        _repeat,
        key_function,
        x,
        shapes=[shape],
        dtypes=[x.dtype],
        chunkss=[chunks],
        extra_projected_mem=extra_projected_mem,
        repeats=repeats,
        axis=axis,
        chunksize=x.chunksize,
    )


def _repeat(x, repeats, axis=None, chunksize=None, block_id=None):
out = nxp.repeat(x, repeats, axis=axis)
bi = block_id[axis] % repeats
ind = tuple(
slice(bi * chunksize[i], (bi + 1) * chunksize[i]) if i == axis else slice(None)
for i in range(x.ndim)
)
return out[ind]


def reshape(x, /, shape, *, copy=None):
# based on dask reshape

Expand Down
9 changes: 9 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,15 @@ def test_permute_dims(spec, executor):
)


def test_repeat(spec):
    """Check that repeating along axis 1 of a chunked array matches NumPy."""
    data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    # A 3x3 array split into 2x2 chunks, so the trailing chunks are partial.
    chunked = xp.asarray(data, chunks=(2, 2), spec=spec)
    repeated = xp.repeat(chunked, 3, axis=1)
    actual = repeated.compute()
    expected = np.repeat(np.array(data), 3, axis=1)
    assert_array_equal(actual, expected)


def test_reshape(spec, executor):
a = xp.arange(12, chunks=4, spec=spec)
b = xp.reshape(a, (3, 4))
Expand Down
9 changes: 9 additions & 0 deletions cubed/tests/test_mem_utilization.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,15 @@ def test_flip_multiple_axes(tmp_path, spec, executor):
run_operation(tmp_path, executor, "flip_multiple_axes", b)


@pytest.mark.slow
def test_repeat(tmp_path, spec, executor):
    """Run a repeat along axis 0 of a large random array via ``run_operation``."""
    shape = (10000, 10000)
    chunk_shape = (5000, 5000)  # 200MB chunks
    source = cubed.random.random(shape, chunks=chunk_shape, spec=spec)
    repeated = xp.repeat(source, 3, axis=0)
    run_operation(tmp_path, executor, "repeat", repeated)


@pytest.mark.slow
def test_reshape(tmp_path, spec, executor):
a = cubed.random.random(
Expand Down

0 comments on commit a2b1053

Please sign in to comment.