Skip to content

Commit

Permalink
[core][experimental] Raise an exception if a leaf node is found durin…
Browse files Browse the repository at this point in the history
…g compilation (#47757)

Leaf nodes are nodes that are not output nodes and have no downstream nodes. If a leaf node raises an exception, it will not be propagated to the driver. Therefore, this PR raises an exception if a leaf node is found during compilation.

Another solution: implicitly add leaf node to MultiOutputNode
Currently, the function `execute` can return multiple `CompiledDAGRef`s. The UX we wanted to provide was to implicitly add leaf nodes to the `MultiOutputNode` but not return the references of the leaf nodes. For example, for a `MultiOutputNode` containing 3 DAG nodes (2 normal DAG nodes + 1 leaf node), the caller would receive only two refs:

x, y = compiled_dag.execute(input_vals) # We don't return the ref for the leaf node.
However, the ref for the leaf node would be garbage-collected inside `execute`, and `CompiledDAGRef.__del__` calls `get` if it was never called. This would turn `execute` into a synchronous operation instead of an asynchronous one, which is not acceptable.
  • Loading branch information
kevin85421 authored Oct 29, 2024
1 parent 5e252d7 commit 0935c36
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 42 deletions.
21 changes: 21 additions & 0 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,27 @@ def _preprocess(self) -> None:
# Add all readers to the NCCL actors of P2P.
nccl_actors_p2p.add(downstream_actor_handle)

# Collect all leaf nodes, i.e., ClassMethodNode tasks that have no downstream
# tasks and are not flagged as output nodes. Tasks wrapping any other node
# type are excluded by the isinstance check.
leaf_nodes = [
    task.dag_node
    for task in self.idx_to_task.values()
    if isinstance(task.dag_node, ClassMethodNode)
    and len(task.downstream_task_idxs) == 0
    and not task.dag_node.is_adag_output_node
]
# Leaf nodes are not allowed because the exception thrown by the leaf
# node will not be propagated to the driver.
if leaf_nodes:
    # Name the offending methods so the user knows which outputs to add.
    leaf_method_names = [leaf_node.get_method_name() for leaf_node in leaf_nodes]
    raise ValueError(
        "Compiled DAG doesn't support leaf nodes, i.e., nodes that don't have "
        "downstream nodes and are not output nodes. There are "
        f"{len(leaf_nodes)} leaf nodes in the DAG. Please add the outputs of "
        f"{leaf_method_names} to the MultiOutputNode."
    )

nccl_actors_p2p = list(nccl_actors_p2p)
if None in nccl_actors_p2p:
raise ValueError("Driver cannot participate in the NCCL group.")
Expand Down
106 changes: 66 additions & 40 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,30 +174,6 @@ def test_basic(ray_start_regular):
del result


def test_two_returns_first(ray_start_regular):
a = Actor.remote(0)
with InputNode() as i:
o1, o2 = a.return_two.bind(i)
dag = o1

compiled_dag = dag.experimental_compile()
for _ in range(3):
res = ray.get(compiled_dag.execute(1))
assert res == 1


def test_two_returns_second(ray_start_regular):
a = Actor.remote(0)
with InputNode() as i:
o1, o2 = a.return_two.bind(i)
dag = o2

compiled_dag = dag.experimental_compile()
for _ in range(3):
res = ray.get(compiled_dag.execute(1))
assert res == 2


@pytest.mark.parametrize("single_fetch", [True, False])
def test_two_returns_one_reader(ray_start_regular, single_fetch):
a = Actor.remote(0)
Expand Down Expand Up @@ -1262,15 +1238,15 @@ def test_compile_twice_with_different_nodes(self, ray_start_regular):
with InputNode() as i:
branch1 = a.echo.bind(i)
branch2 = b.echo.bind(i)
dag = MultiOutputNode([branch1])
dag = MultiOutputNode([branch1, branch2])
compiled_dag = dag.experimental_compile()
compiled_dag.teardown()
with pytest.raises(
ValueError,
match="The DAG was compiled more than once. The following two "
"nodes call `experimental_compile`: ",
):
compiled_dag = branch2.experimental_compile()
branch2.experimental_compile()


def test_exceed_max_buffered_results(ray_start_regular):
Expand Down Expand Up @@ -1782,15 +1758,22 @@ def test_intra_process_channel_with_multi_readers(


class TestLeafNode:
"""
Leaf nodes are not allowed right now because the exception thrown by the leaf
node will not be propagated to the driver and silently ignored, which is undesired.
"""

LEAF_NODE_EXCEPTION_TEMPLATE = (
"Compiled DAG doesn't support leaf nodes, i.e., nodes that don't have "
"downstream nodes and are not output nodes. There are {num_leaf_nodes} "
"leaf nodes in the DAG. Please add the outputs of"
)

def test_leaf_node_one_actor(self, ray_start_regular):
"""
driver -> a.inc
|
-> a.inc -> driver
The upper branch (branch 1) is a leaf node, and it will be executed
before the lower `a.inc` task because of the control dependency. Hence,
the result will be [20] because `a.inc` will be executed twice.
"""
a = Actor.remote(0)
with InputNode() as i:
Expand All @@ -1799,10 +1782,11 @@ def test_leaf_node_one_actor(self, ray_start_regular):
branch2 = a.inc.bind(input_data)
dag = MultiOutputNode([branch2])

compiled_dag = dag.experimental_compile()

ref = compiled_dag.execute(10)
assert ray.get(ref) == [20]
with pytest.raises(
ValueError,
match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=1),
):
dag.experimental_compile()

def test_leaf_node_two_actors(self, ray_start_regular):
"""
Expand All @@ -1811,20 +1795,62 @@ def test_leaf_node_two_actors(self, ray_start_regular):
| -> b.inc ----> driver
|
-> a.inc (branch 1)
The lower branch (branch 1) is a leaf node, and it will be executed
before the upper `a.inc` task because of the control dependency.
"""
a = Actor.remote(0)
b = Actor.remote(100)
with InputNode() as i:
a.inc.bind(i) # branch1: leaf node
branch2 = b.inc.bind(i)
dag = MultiOutputNode([a.inc.bind(branch2), b.inc.bind(branch2)])
compiled_dag = dag.experimental_compile()
with pytest.raises(
ValueError,
match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=1),
):
dag.experimental_compile()

def test_multi_leaf_nodes(self, ray_start_regular):
"""
driver -> a.inc -> a.inc (branch 1, leaf node)
| |
| -> a.inc -> driver
|
-> a.inc (branch 2, leaf node)
"""
a = Actor.remote(0)
with InputNode() as i:
dag = a.inc.bind(i)
a.inc.bind(dag) # branch1: leaf node
a.inc.bind(i) # branch2: leaf node
dag = MultiOutputNode([a.inc.bind(dag)])

with pytest.raises(
ValueError,
match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=2),
):
dag.experimental_compile()

ref = compiled_dag.execute(10)
assert ray.get(ref) == [120, 220]
def test_two_returns_first(self, ray_start_regular):
a = Actor.remote(0)
with InputNode() as i:
o1, o2 = a.return_two.bind(i)
dag = o1

with pytest.raises(
ValueError,
match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=1),
):
dag.experimental_compile()

def test_two_returns_second(self, ray_start_regular):
a = Actor.remote(0)
with InputNode() as i:
o1, o2 = a.return_two.bind(i)
dag = o2
with pytest.raises(
ValueError,
match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=1),
):
dag.experimental_compile()


def test_output_node(ray_start_regular):
Expand Down
5 changes: 5 additions & 0 deletions python/ray/dag/tests/experimental/test_collective_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def test_comm_deduplicate_p2p_and_collective(ray_start_regular, monkeypatch):
dag = workers[1].recv.bind(
collectives[0].with_type_hint(TorchTensorType(transport="nccl"))
)
dag = MultiOutputNode([dag, collectives[1]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
monkeypatch,
Expand Down Expand Up @@ -435,6 +436,7 @@ def test_custom_comm_deduplicate(ray_start_regular, monkeypatch):
dag = workers[0].recv.bind(
collectives[1].with_type_hint(TorchTensorType(transport="nccl"))
)
dag = MultiOutputNode([dag, collectives[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
monkeypatch,
Expand All @@ -453,6 +455,7 @@ def test_custom_comm_deduplicate(ray_start_regular, monkeypatch):
dag = workers[0].recv.bind(
collectives[1].with_type_hint(TorchTensorType(transport=comm))
)
dag = MultiOutputNode([dag, collectives[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
monkeypatch,
Expand Down Expand Up @@ -487,6 +490,7 @@ def test_custom_comm_init_teardown(ray_start_regular, monkeypatch):
dag = workers[0].recv.bind(
allreduce[1].with_type_hint(TorchTensorType(transport=comm))
)
dag = MultiOutputNode([dag, allreduce[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
monkeypatch,
Expand All @@ -508,6 +512,7 @@ def test_custom_comm_init_teardown(ray_start_regular, monkeypatch):
dag = workers[0].recv.bind(
allreduce2[1].with_type_hint(TorchTensorType(transport=comm_3))
)
dag = MultiOutputNode([dag, allreduce2[0]])

compiled_dag, mock_nccl_group_set = check_nccl_group_init(
monkeypatch,
Expand Down
4 changes: 2 additions & 2 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def test_torch_tensor_nccl_all_reduce_get_partial(ray_start_regular):
collectives = collective.allreduce.bind(computes, ReduceOp.SUM)
recv = workers[0].recv.bind(collectives[0])
tensor = workers[1].recv_tensor.bind(collectives[0])
dag = MultiOutputNode([recv, tensor])
dag = MultiOutputNode([recv, tensor, collectives[1]])

compiled_dag = dag.experimental_compile()

Expand All @@ -873,7 +873,7 @@ def test_torch_tensor_nccl_all_reduce_get_partial(ray_start_regular):
[(shape, dtype, i + idx + 1) for idx in range(num_workers)]
)
result = ray.get(ref)
metadata, tensor = result
metadata, tensor, _ = result
reduced_val = sum(i + idx + 1 for idx in range(num_workers))
assert metadata == (reduced_val, shape, dtype)
tensor = tensor.to("cpu")
Expand Down

0 comments on commit 0935c36

Please sign in to comment.