From 0935c3668a76e14fbd5e65acaf770181f110ba70 Mon Sep 17 00:00:00 2001 From: Kai-Hsun Chen Date: Tue, 29 Oct 2024 11:17:38 -0700 Subject: [PATCH] [core][experimental] Raise an exception if a leaf node is found during compilation (#47757) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Leaf nodes are nodes that are not output nodes and have no downstream nodes. If a leaf node raises an exception, it will not be propagated to the driver. Therefore, this PR raises an exception if a leaf node is found during compilation. Another solution: implicitly add leaf nodes to the MultiOutputNode. Currently, the function execute can return multiple CompiledDAGRefs. The UX we want to provide is to implicitly add leaf nodes to the MultiOutputNode but not return the references of the leaf nodes. For example, suppose a MultiOutputNode contains 3 DAG nodes (2 normal DAG nodes + 1 leaf node). x, y = compiled_dag.execute(input_vals) # We don't return the ref for the leaf node. However, the ref for the leaf node would be garbage-collected inside execute, and CompiledDAGRef’s __del__ calls get if it was never called. That would make execute a synchronous operation instead of an asynchronous one, which is not acceptable. --- python/ray/dag/compiled_dag_node.py | 21 ++++ .../experimental/test_accelerated_dag.py | 106 +++++++++++------- .../tests/experimental/test_collective_dag.py | 5 + .../experimental/test_torch_tensor_dag.py | 4 +- 4 files changed, 94 insertions(+), 42 deletions(-) diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 2b4bfa68f30a..60a564f9e31e 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -973,6 +973,27 @@ def _preprocess(self) -> None: # Add all readers to the NCCL actors of P2P. nccl_actors_p2p.add(downstream_actor_handle) + # Collect all leaf nodes. 
+ leaf_nodes: DAGNode = [] + for idx, task in self.idx_to_task.items(): + if not isinstance(task.dag_node, ClassMethodNode): + continue + if ( + len(task.downstream_task_idxs) == 0 + and not task.dag_node.is_adag_output_node + ): + leaf_nodes.append(task.dag_node) + # Leaf nodes are not allowed because the exception thrown by the leaf + # node will not be propagated to the driver. + if len(leaf_nodes) != 0: + raise ValueError( + "Compiled DAG doesn't support leaf nodes, i.e., nodes that don't have " + "downstream nodes and are not output nodes. There are " + f"{len(leaf_nodes)} leaf nodes in the DAG. Please add the outputs of " + f"{[leaf_node.get_method_name() for leaf_node in leaf_nodes]} to the " + f"the MultiOutputNode." + ) + nccl_actors_p2p = list(nccl_actors_p2p) if None in nccl_actors_p2p: raise ValueError("Driver cannot participate in the NCCL group.") diff --git a/python/ray/dag/tests/experimental/test_accelerated_dag.py b/python/ray/dag/tests/experimental/test_accelerated_dag.py index 38661e1a73ad..e7463d8d2084 100644 --- a/python/ray/dag/tests/experimental/test_accelerated_dag.py +++ b/python/ray/dag/tests/experimental/test_accelerated_dag.py @@ -174,30 +174,6 @@ def test_basic(ray_start_regular): del result -def test_two_returns_first(ray_start_regular): - a = Actor.remote(0) - with InputNode() as i: - o1, o2 = a.return_two.bind(i) - dag = o1 - - compiled_dag = dag.experimental_compile() - for _ in range(3): - res = ray.get(compiled_dag.execute(1)) - assert res == 1 - - -def test_two_returns_second(ray_start_regular): - a = Actor.remote(0) - with InputNode() as i: - o1, o2 = a.return_two.bind(i) - dag = o2 - - compiled_dag = dag.experimental_compile() - for _ in range(3): - res = ray.get(compiled_dag.execute(1)) - assert res == 2 - - @pytest.mark.parametrize("single_fetch", [True, False]) def test_two_returns_one_reader(ray_start_regular, single_fetch): a = Actor.remote(0) @@ -1262,7 +1238,7 @@ def test_compile_twice_with_different_nodes(self, 
ray_start_regular): with InputNode() as i: branch1 = a.echo.bind(i) branch2 = b.echo.bind(i) - dag = MultiOutputNode([branch1]) + dag = MultiOutputNode([branch1, branch2]) compiled_dag = dag.experimental_compile() compiled_dag.teardown() with pytest.raises( @@ -1270,7 +1246,7 @@ def test_compile_twice_with_different_nodes(self, ray_start_regular): match="The DAG was compiled more than once. The following two " "nodes call `experimental_compile`: ", ): - compiled_dag = branch2.experimental_compile() + branch2.experimental_compile() def test_exceed_max_buffered_results(ray_start_regular): @@ -1782,15 +1758,22 @@ def test_intra_process_channel_with_multi_readers( class TestLeafNode: + """ + Leaf nodes are not allowed right now because the exception thrown by the leaf + node will not be propagated to the driver and silently ignored, which is undesired. + """ + + LEAF_NODE_EXCEPTION_TEMPLATE = ( + "Compiled DAG doesn't support leaf nodes, i.e., nodes that don't have " + "downstream nodes and are not output nodes. There are {num_leaf_nodes} " + "leaf nodes in the DAG. Please add the outputs of" + ) + def test_leaf_node_one_actor(self, ray_start_regular): """ driver -> a.inc | -> a.inc -> driver - - The upper branch (branch 1) is a leaf node, and it will be executed - before the lower `a.inc` task because of the control dependency. Hence, - the result will be [20] because `a.inc` will be executed twice. 
""" a = Actor.remote(0) with InputNode() as i: @@ -1799,10 +1782,11 @@ def test_leaf_node_one_actor(self, ray_start_regular): branch2 = a.inc.bind(input_data) dag = MultiOutputNode([branch2]) - compiled_dag = dag.experimental_compile() - - ref = compiled_dag.execute(10) - assert ray.get(ref) == [20] + with pytest.raises( + ValueError, + match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=1), + ): + dag.experimental_compile() def test_leaf_node_two_actors(self, ray_start_regular): """ @@ -1811,9 +1795,6 @@ def test_leaf_node_two_actors(self, ray_start_regular): | -> b.inc ----> driver | -> a.inc (branch 1) - - The lower branch (branch 1) is a leaf node, and it will be executed - before the upper `a.inc` task because of the control dependency. """ a = Actor.remote(0) b = Actor.remote(100) @@ -1821,10 +1802,55 @@ def test_leaf_node_two_actors(self, ray_start_regular): a.inc.bind(i) # branch1: leaf node branch2 = b.inc.bind(i) dag = MultiOutputNode([a.inc.bind(branch2), b.inc.bind(branch2)]) - compiled_dag = dag.experimental_compile() + with pytest.raises( + ValueError, + match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=1), + ): + dag.experimental_compile() + + def test_multi_leaf_nodes(self, ray_start_regular): + """ + driver -> a.inc -> a.inc (branch 1, leaf node) + | | + | -> a.inc -> driver + | + -> a.inc (branch 2, leaf node) + """ + a = Actor.remote(0) + with InputNode() as i: + dag = a.inc.bind(i) + a.inc.bind(dag) # branch1: leaf node + a.inc.bind(i) # branch2: leaf node + dag = MultiOutputNode([a.inc.bind(dag)]) + + with pytest.raises( + ValueError, + match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=2), + ): + dag.experimental_compile() - ref = compiled_dag.execute(10) - assert ray.get(ref) == [120, 220] + def test_two_returns_first(self, ray_start_regular): + a = Actor.remote(0) + with InputNode() as i: + o1, o2 = a.return_two.bind(i) + dag = o1 + + with pytest.raises( + ValueError, + 
match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=1), + ): + dag.experimental_compile() + + def test_two_returns_second(self, ray_start_regular): + a = Actor.remote(0) + with InputNode() as i: + o1, o2 = a.return_two.bind(i) + dag = o2 + with pytest.raises( + ValueError, + match=TestLeafNode.LEAF_NODE_EXCEPTION_TEMPLATE.format(num_leaf_nodes=1), + ): + dag.experimental_compile() def test_output_node(ray_start_regular): diff --git a/python/ray/dag/tests/experimental/test_collective_dag.py b/python/ray/dag/tests/experimental/test_collective_dag.py index 680e6fd27dfb..2c9622836632 100644 --- a/python/ray/dag/tests/experimental/test_collective_dag.py +++ b/python/ray/dag/tests/experimental/test_collective_dag.py @@ -404,6 +404,7 @@ def test_comm_deduplicate_p2p_and_collective(ray_start_regular, monkeypatch): dag = workers[1].recv.bind( collectives[0].with_type_hint(TorchTensorType(transport="nccl")) ) + dag = MultiOutputNode([dag, collectives[1]]) compiled_dag, mock_nccl_group_set = check_nccl_group_init( monkeypatch, @@ -435,6 +436,7 @@ def test_custom_comm_deduplicate(ray_start_regular, monkeypatch): dag = workers[0].recv.bind( collectives[1].with_type_hint(TorchTensorType(transport="nccl")) ) + dag = MultiOutputNode([dag, collectives[0]]) compiled_dag, mock_nccl_group_set = check_nccl_group_init( monkeypatch, @@ -453,6 +455,7 @@ def test_custom_comm_deduplicate(ray_start_regular, monkeypatch): dag = workers[0].recv.bind( collectives[1].with_type_hint(TorchTensorType(transport=comm)) ) + dag = MultiOutputNode([dag, collectives[0]]) compiled_dag, mock_nccl_group_set = check_nccl_group_init( monkeypatch, @@ -487,6 +490,7 @@ def test_custom_comm_init_teardown(ray_start_regular, monkeypatch): dag = workers[0].recv.bind( allreduce[1].with_type_hint(TorchTensorType(transport=comm)) ) + dag = MultiOutputNode([dag, allreduce[0]]) compiled_dag, mock_nccl_group_set = check_nccl_group_init( monkeypatch, @@ -508,6 +512,7 @@ def 
test_custom_comm_init_teardown(ray_start_regular, monkeypatch): dag = workers[0].recv.bind( allreduce2[1].with_type_hint(TorchTensorType(transport=comm_3)) ) + dag = MultiOutputNode([dag, allreduce2[0]]) compiled_dag, mock_nccl_group_set = check_nccl_group_init( monkeypatch, diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index 5d4f06e528e7..4b4214e89d3c 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -864,7 +864,7 @@ def test_torch_tensor_nccl_all_reduce_get_partial(ray_start_regular): collectives = collective.allreduce.bind(computes, ReduceOp.SUM) recv = workers[0].recv.bind(collectives[0]) tensor = workers[1].recv_tensor.bind(collectives[0]) - dag = MultiOutputNode([recv, tensor]) + dag = MultiOutputNode([recv, tensor, collectives[1]]) compiled_dag = dag.experimental_compile() @@ -873,7 +873,7 @@ def test_torch_tensor_nccl_all_reduce_get_partial(ray_start_regular): [(shape, dtype, i + idx + 1) for idx in range(num_workers)] ) result = ray.get(ref) - metadata, tensor = result + metadata, tensor, _ = result reduced_val = sum(i + idx + 1 for idx in range(num_workers)) assert metadata == (reduced_val, shape, dtype) tensor = tensor.to("cpu")