Skip to content

Commit

Permalink
Take allowed_mem into account when fusing primitive operations (#366)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Jan 30, 2024
1 parent 119c063 commit 1762d3c
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 20 deletions.
18 changes: 11 additions & 7 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,21 +137,24 @@ def _create_lazy_zarr_arrays(self, dag):
# find all lazy zarr arrays in dag
all_pipeline_nodes = []
lazy_zarr_arrays = []
reserved_mem_values = []
allowed_mem = 0
reserved_mem = 0
for n, d in dag.nodes(data=True):
if "primitive_op" in d and d["primitive_op"].reserved_mem is not None:
reserved_mem_values.append(d["primitive_op"].reserved_mem)
if "primitive_op" in d:
all_pipeline_nodes.append(n)
allowed_mem = max(allowed_mem, d["primitive_op"].allowed_mem)
reserved_mem = max(reserved_mem, d["primitive_op"].reserved_mem)

if "target" in d and isinstance(d["target"], LazyZarrArray):
lazy_zarr_arrays.append(d["target"])

reserved_mem = max(reserved_mem_values, default=0)

if len(lazy_zarr_arrays) > 0:
# add new node and edges
name = "create-arrays"
op_name = name
primitive_op = create_zarr_arrays(lazy_zarr_arrays, reserved_mem)
primitive_op = create_zarr_arrays(
lazy_zarr_arrays, allowed_mem, reserved_mem
)
dag.add_node(
name,
name=name,
Expand Down Expand Up @@ -429,7 +432,7 @@ def create_zarr_array(lazy_zarr_array, *, config=None):
lazy_zarr_array.create(mode="a")


def create_zarr_arrays(lazy_zarr_arrays, reserved_mem):
def create_zarr_arrays(lazy_zarr_arrays, allowed_mem, reserved_mem):
# projected memory is size of largest dtype size (for a fill value)
projected_mem = (
max([lza.dtype.itemsize for lza in lazy_zarr_arrays], default=0) + reserved_mem
Expand All @@ -446,6 +449,7 @@ def create_zarr_arrays(lazy_zarr_arrays, reserved_mem):
pipeline=pipeline,
target_array=None,
projected_mem=projected_mem,
allowed_mem=allowed_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=False,
Expand Down
31 changes: 24 additions & 7 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product
from cubed.vendor.dask.core import flatten

from .types import CubedArrayProxy, PrimitiveOperation
from .types import CubedArrayProxy, MemoryModeller, PrimitiveOperation

sym_counter = 0

Expand Down Expand Up @@ -313,6 +313,7 @@ def general_blockwise(
pipeline=pipeline,
target_array=target_array,
projected_mem=projected_mem,
allowed_mem=allowed_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=fusable,
Expand Down Expand Up @@ -343,12 +344,27 @@ def can_fuse_multiple_primitive_ops(
if is_fuse_candidate(primitive_op) and all(
is_fuse_candidate(p) for p in predecessor_primitive_ops
):
# if the peak projected memory for running all the predecessor ops in order is
# larger than allowed_mem then we can't fuse
if peak_projected_mem(predecessor_primitive_ops) > primitive_op.allowed_mem:
return False
return all(
primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops
)
return False


def peak_projected_mem(primitive_ops):
    """Calculate the peak projected memory for running a series of primitive ops
    and retaining their return values in memory.

    Each op transiently needs its full ``projected_mem`` while it runs; once it
    completes, only one output chunk's worth of memory is retained.
    """
    modeller = MemoryModeller()
    for op in primitive_ops:
        # Running the op requires its full projected memory...
        modeller.allocate(op.projected_mem)
        # ...but afterwards only a single result chunk stays resident.
        retained = chunk_memory(op.target_array.dtype, op.target_array.chunks)
        modeller.free(op.projected_mem - retained)
    return modeller.peak_mem


def fuse(
primitive_op1: PrimitiveOperation, primitive_op2: PrimitiveOperation
) -> PrimitiveOperation:
Expand Down Expand Up @@ -380,7 +396,8 @@ def fused_func(*args):

target_array = primitive_op2.target_array
projected_mem = max(primitive_op1.projected_mem, primitive_op2.projected_mem)
reserved_mem = max(primitive_op1.reserved_mem, primitive_op2.reserved_mem)
allowed_mem = primitive_op2.allowed_mem
reserved_mem = primitive_op2.reserved_mem
num_tasks = primitive_op2.num_tasks

pipeline = CubedPipeline(
Expand All @@ -393,6 +410,7 @@ def fused_func(*args):
pipeline=pipeline,
target_array=target_array,
projected_mem=projected_mem,
allowed_mem=allowed_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=True,
Expand Down Expand Up @@ -467,12 +485,10 @@ def fused_func(*args):
target_array = primitive_op.target_array
projected_mem = max(
primitive_op.projected_mem,
*(p.projected_mem for p in predecessor_primitive_ops if p is not None),
)
reserved_mem = max(
primitive_op.reserved_mem,
*(p.reserved_mem for p in predecessor_primitive_ops if p is not None),
peak_projected_mem(p for p in predecessor_primitive_ops if p is not None),
)
allowed_mem = primitive_op.allowed_mem
reserved_mem = primitive_op.reserved_mem
num_tasks = primitive_op.num_tasks

fused_pipeline = CubedPipeline(
Expand All @@ -485,6 +501,7 @@ def fused_func(*args):
pipeline=fused_pipeline,
target_array=target_array,
projected_mem=projected_mem,
allowed_mem=allowed_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=True,
Expand Down
13 changes: 10 additions & 3 deletions cubed/primitive/rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def rechunk(
num_tasks = total_chunks(write_proxy.array.shape, write_proxy.chunks)
return [
spec_to_primitive_op(
copy_spec, target, projected_mem, reserved_mem, num_tasks
copy_spec, target, projected_mem, allowed_mem, reserved_mem, num_tasks
)
]

Expand All @@ -81,13 +81,18 @@ def rechunk(
copy_spec1 = CubedCopySpec(read_proxy, int_proxy)
num_tasks = total_chunks(copy_spec1.write.array.shape, copy_spec1.write.chunks)
op1 = spec_to_primitive_op(
copy_spec1, intermediate, projected_mem, reserved_mem, num_tasks
copy_spec1,
intermediate,
projected_mem,
allowed_mem,
reserved_mem,
num_tasks,
)

copy_spec2 = CubedCopySpec(int_proxy, write_proxy)
num_tasks = total_chunks(copy_spec2.write.array.shape, copy_spec2.write.chunks)
op2 = spec_to_primitive_op(
copy_spec2, target, projected_mem, reserved_mem, num_tasks
copy_spec2, target, projected_mem, allowed_mem, reserved_mem, num_tasks
)

return [op1, op2]
Expand Down Expand Up @@ -191,6 +196,7 @@ def spec_to_primitive_op(
spec: CubedCopySpec,
target_array: Any,
projected_mem: int,
allowed_mem: int,
reserved_mem: int,
num_tasks: int,
) -> PrimitiveOperation:
Expand All @@ -206,6 +212,7 @@ def spec_to_primitive_op(
pipeline=pipeline,
target_array=target_array,
projected_mem=projected_mem,
allowed_mem=allowed_mem,
reserved_mem=reserved_mem,
num_tasks=num_tasks,
fusable=False,
Expand Down
22 changes: 22 additions & 0 deletions cubed/primitive/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ class PrimitiveOperation:
projected_mem: int
"""An upper bound of the memory needed to run a task, in bytes."""

allowed_mem: int
"""
The total memory available to a worker for running a task, in bytes.
This includes any ``reserved_mem`` that has been set.
"""

reserved_mem: int
"""The memory reserved on a worker for non-data use when running a task, in bytes."""

Expand Down Expand Up @@ -51,3 +58,18 @@ class CubedCopySpec:

read: CubedArrayProxy
write: CubedArrayProxy


class MemoryModeller:
    """Models peak memory usage for a series of operations."""

    # Running total of bytes currently in use.
    current_mem: int = 0
    # High-water mark of ``current_mem`` seen so far.
    peak_mem: int = 0

    def _record(self, delta: int) -> None:
        # Apply a signed change to current memory and refresh the peak.
        # Note a negative ``free`` (i.e. positive delta here) can raise the
        # peak, so the high-water mark is updated on every change.
        self.current_mem += delta
        if self.current_mem > self.peak_mem:
            self.peak_mem = self.current_mem

    def allocate(self, num_bytes):
        self._record(num_bytes)

    def free(self, num_bytes):
        self._record(-num_bytes)
15 changes: 15 additions & 0 deletions cubed/tests/primitive/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from cubed.primitive.types import MemoryModeller


def test_memory_modeller():
    # A fresh modeller starts with nothing allocated.
    modeller = MemoryModeller()
    assert (modeller.current_mem, modeller.peak_mem) == (0, 0)

    # Allocation raises both the running total and the high-water mark.
    modeller.allocate(100)
    assert (modeller.current_mem, modeller.peak_mem) == (100, 100)

    # Freeing lowers the running total but the peak is sticky.
    modeller.free(50)
    assert (modeller.current_mem, modeller.peak_mem) == (50, 100)
64 changes: 61 additions & 3 deletions cubed/tests/test_mem_utilization.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import math
import shutil
from functools import partial, reduce

import pytest

from cubed.core.ops import partial_reduce
from cubed.core.optimization import multiple_inputs_optimize_dag

pytest.importorskip("lithops")

Expand Down Expand Up @@ -87,6 +89,56 @@ def test_add(tmp_path, spec):
run_operation(tmp_path, "add", c)


@pytest.mark.slow
def test_add_reduce_left(tmp_path, spec):
    # Repeatedly `add` pairs of arrays, accumulating from the left (fold left).
    # See https://en.wikipedia.org/wiki/Fold_(higher-order_function)
    #
    #   o   o
    #    \ /
    #     o   o
    #      \ /
    #       o   o
    #        \ /
    #         o
    #
    # Fusing fold left operations will result in a single fused operation.
    n_arrays = 10
    arrs = [
        cubed.random.random((10000, 10000), chunks=(5000, 5000), spec=spec)
        for _ in range(n_arrays)
    ]
    # reduce folds left-to-right, so xp.add can be passed directly.
    result = reduce(xp.add, arrs)
    opt_fn = partial(multiple_inputs_optimize_dag, max_total_source_arrays=n_arrays * 2)
    run_operation(tmp_path, "add_reduce_left", result, optimize_function=opt_fn)


@pytest.mark.slow
def test_add_reduce_right(tmp_path, spec):
    # Perform the `add` operation repeatedly on pairs of arrays, also known as fold right.
    # See https://en.wikipedia.org/wiki/Fold_(higher-order_function)
    #
    #         o   o
    #          \ /
    #       o   o
    #        \ /
    #     o   o
    #      \ /
    #       o
    #
    # Note that fusing fold right operations will result in unbounded memory usage unless care
    # is taken to limit fusion - which `multiple_inputs_optimize_dag` will do, with the result
    # that there is more than one fused operation (not a single fused operation).
    n_arrays = 10
    arrs = [
        cubed.random.random((10000, 10000), chunks=(5000, 5000), spec=spec)
        for _ in range(n_arrays)
    ]
    # Swapping the operands and reversing the list makes reduce behave as a fold right.
    result = reduce(lambda x, y: xp.add(y, x), reversed(arrs))
    opt_fn = partial(multiple_inputs_optimize_dag, max_total_source_arrays=n_arrays * 2)
    run_operation(tmp_path, "add_reduce_right", result, optimize_function=opt_fn)


@pytest.mark.slow
def test_negative(tmp_path, spec):
a = cubed.random.random(
Expand Down Expand Up @@ -220,13 +272,19 @@ def test_sum_partial_reduce(tmp_path, spec):
# Internal functions


def run_operation(tmp_path, name, result_array):
def run_operation(tmp_path, name, result_array, *, optimize_function=None):
# result_array.visualize(f"cubed-{name}-unoptimized", optimize_graph=False)
# result_array.visualize(f"cubed-{name}")
# result_array.visualize(f"cubed-{name}", optimize_function=optimize_function)
executor = LithopsDagExecutor(config=LITHOPS_LOCAL_CONFIG)
hist = HistoryCallback()
# use store=None to write to temporary zarr
cubed.to_zarr(result_array, store=None, executor=executor, callbacks=[hist])
cubed.to_zarr(
result_array,
store=None,
executor=executor,
callbacks=[hist],
optimize_function=optimize_function,
)

df = hist.stats_df
print(df)
Expand Down

0 comments on commit 1762d3c

Please sign in to comment.