Skip to content

Commit

Permalink
Introduce max_total_num_input_blocks as a heuristic to control fusion (
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Feb 12, 2024
1 parent f3295e7 commit 9ad8fbe
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 16 deletions.
29 changes: 25 additions & 4 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,13 @@ def is_fusable(node_dict):


def can_fuse_predecessors(
dag, name, *, max_total_source_arrays=4, always_fuse=None, never_fuse=None
dag,
name,
*,
max_total_source_arrays=4,
max_total_num_input_blocks=None,
always_fuse=None,
never_fuse=None,
):
nodes = dict(dag.nodes(data=True))

Expand Down Expand Up @@ -130,12 +136,20 @@ def can_fuse_predecessors(
if is_fusable(nodes[pre])
]
return can_fuse_multiple_primitive_ops(
nodes[name]["primitive_op"], *predecessor_primitive_ops
nodes[name]["primitive_op"],
predecessor_primitive_ops,
max_total_num_input_blocks=max_total_num_input_blocks,
)


def fuse_predecessors(
dag, name, *, max_total_source_arrays=4, always_fuse=None, never_fuse=None
dag,
name,
*,
max_total_source_arrays=4,
max_total_num_input_blocks=None,
always_fuse=None,
never_fuse=None,
):
"""Fuse a node with its immediate predecessors."""

Expand All @@ -144,6 +158,7 @@ def fuse_predecessors(
dag,
name,
max_total_source_arrays=max_total_source_arrays,
max_total_num_input_blocks=max_total_num_input_blocks,
always_fuse=always_fuse,
never_fuse=never_fuse,
):
Expand Down Expand Up @@ -195,14 +210,20 @@ def fuse_predecessors(


def multiple_inputs_optimize_dag(
dag, *, max_total_source_arrays=4, always_fuse=None, never_fuse=None
dag,
*,
max_total_source_arrays=4,
max_total_num_input_blocks=None,
always_fuse=None,
never_fuse=None,
):
"""Fuse multiple inputs."""
for name in list(nx.topological_sort(dag)):
dag = fuse_predecessors(
dag,
name,
max_total_source_arrays=max_total_source_arrays,
max_total_num_input_blocks=max_total_num_input_blocks,
always_fuse=always_fuse,
never_fuse=never_fuse,
)
Expand Down
20 changes: 16 additions & 4 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,10 @@ def can_fuse_primitive_ops(


def can_fuse_multiple_primitive_ops(
primitive_op: PrimitiveOperation, *predecessor_primitive_ops: PrimitiveOperation
primitive_op: PrimitiveOperation,
predecessor_primitive_ops: List[PrimitiveOperation],
*,
max_total_num_input_blocks: Optional[int] = None,
) -> bool:
if is_fuse_candidate(primitive_op) and all(
is_fuse_candidate(p) for p in predecessor_primitive_ops
Expand All @@ -368,9 +371,18 @@ def can_fuse_multiple_primitive_ops(
num_input_blocks = primitive_op.pipeline.config.num_input_blocks
if not all(num_input_blocks[0] == n for n in num_input_blocks):
return False
return all(
primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops
)
if max_total_num_input_blocks is None:
# If max total input blocks not specified, then only fuse if num
# tasks of predecessor ops match.
return all(
primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops
)
else:
total_num_input_blocks = 0
for ni, p in zip(num_input_blocks, predecessor_primitive_ops):
for nj in p.pipeline.config.num_input_blocks:
total_num_input_blocks += ni * nj
return total_num_input_blocks <= max_total_num_input_blocks
return False


Expand Down
16 changes: 8 additions & 8 deletions cubed/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,12 @@ def fuse_one_level(arr, *, always_fuse=None):
)


def fuse_multiple_levels(*, max_total_source_arrays=4):
def fuse_multiple_levels(*, max_total_source_arrays=4, max_total_num_input_blocks=None):
# use multiple_inputs_optimize_dag to test multiple levels of fusion
return partial(
multiple_inputs_optimize_dag, max_total_source_arrays=max_total_source_arrays
multiple_inputs_optimize_dag,
max_total_source_arrays=max_total_source_arrays,
max_total_num_input_blocks=max_total_num_input_blocks,
)


Expand Down Expand Up @@ -775,9 +777,8 @@ def test_fuse_merge_chunks_unary(spec):
b = xp.negative(a)
c = merge_chunks_new(b, chunks=(3, 2))

# force c to fuse
last_op = sorted(c.plan.dag.nodes())[-1]
opt_fn = fuse_one_level(c, always_fuse=[last_op])
# specify max_total_num_input_blocks to force c to fuse
opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3)

c.visualize(optimize_function=opt_fn)

Expand Down Expand Up @@ -809,9 +810,8 @@ def test_fuse_merge_chunks_binary(spec):
c = xp.add(a, b)
d = merge_chunks_new(c, chunks=(3, 2))

# force d to fuse
last_op = sorted(d.plan.dag.nodes())[-1]
opt_fn = fuse_one_level(d, always_fuse=[last_op])
# specify max_total_num_input_blocks to force d to fuse
opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6)

d.visualize(optimize_function=opt_fn)

Expand Down

0 comments on commit 9ad8fbe

Please sign in to comment.