[aDAG] Support all reduce collective in aDAG #47621

Merged: 96 commits, Oct 21, 2024

Commits (96)
dd8afbe
(WIP) chore: Add input index in reader
dengwxn Aug 8, 2024
d957669
chore: Clean up
dengwxn Aug 8, 2024
0df4191
(WIP) chore: Update channel allocation for TaskReturnNode
dengwxn Aug 8, 2024
952daa0
(WIP) chore: Add tests
dengwxn Aug 9, 2024
3479587
chore: Return TaskReturnNode only when num_returns > 1
dengwxn Aug 9, 2024
3312712
chore: Add test for two returns three actors
dengwxn Aug 9, 2024
97faeb7
chore: Clean up
dengwxn Aug 9, 2024
197e65e
chore: Adjust input_idxs
dengwxn Aug 9, 2024
a48fecd
(WIP) chore: Allocate an output channel for each output value
dengwxn Aug 16, 2024
16e0323
(WIP) chore: Remove legacy dependencies
dengwxn Aug 16, 2024
d12f28b
(WIP) chore: Remove legacy dependencies
dengwxn Aug 16, 2024
00a6de7
chore: Update async writer
dengwxn Aug 16, 2024
99a42f6
chore: Update comments
dengwxn Aug 16, 2024
40c6bb2
chore: Remove legacy
dengwxn Aug 16, 2024
6ef6e60
chore: Update tests
dengwxn Aug 22, 2024
6772b3a
chore: Revert read_by_multi_output_node
dengwxn Aug 22, 2024
b35e632
chore: Add comment
dengwxn Aug 22, 2024
f2004a3
chore: Rename output task to upstream/downstream task
dengwxn Aug 22, 2024
a65b053
chore: Update tests
dengwxn Aug 22, 2024
db1d796
chore: Update private fields
dengwxn Aug 22, 2024
ce4694e
Merge branch 'master' into master
dengwxn Aug 22, 2024
d2d50f3
feat: Merge ClassMethodOutputNode into ClassMethodNode
dengwxn Aug 23, 2024
e64a3da
Merge branch 'master' into master
dengwxn Aug 23, 2024
007f656
chore: Add tests and unify internal fields
dengwxn Aug 28, 2024
c8e6a5b
chore: Adjust comments and tests
dengwxn Aug 29, 2024
c729e68
chore: Format code
dengwxn Aug 29, 2024
39b1953
chore: Merge master
dengwxn Aug 29, 2024
39a76c5
chore: Fix check of readers for each output
dengwxn Aug 29, 2024
a84ad8b
chore: Format code
dengwxn Aug 29, 2024
ce472ae
chore: Format code
dengwxn Aug 29, 2024
d04e687
chore: Format code
dengwxn Aug 29, 2024
315a192
chore: Fix typo
dengwxn Aug 30, 2024
f050638
chore: Fix typo
dengwxn Aug 30, 2024
8fc7cd0
chore: Fix typo
dengwxn Aug 30, 2024
a88defb
chore: Fix typo
dengwxn Aug 30, 2024
7e3eb68
chore: Merge master
dengwxn Aug 30, 2024
1d08bd7
chore: Format code
dengwxn Aug 30, 2024
28bc144
chore: Fix output channels
dengwxn Aug 30, 2024
81915af
Merge branch 'master' of https://github.com/ray-project/ray
dengwxn Aug 31, 2024
dbb85dc
chore: Fix reader_handles_set
dengwxn Aug 31, 2024
9af6490
Merge branch 'master' of github.com:dengwxn/ray
dengwxn Aug 31, 2024
e463a61
Merge branch 'master' of https://github.com/ray-project/ray
dengwxn Aug 31, 2024
b8173b7
chore: Fix mock class method call in tests
dengwxn Sep 1, 2024
d76a106
Merge branch 'master' into master
dengwxn Sep 2, 2024
39a4452
Merge branch 'ray-project:master' into master
dengwxn Sep 2, 2024
dc2586d
Merge branch 'ray-project:master' into master
dengwxn Sep 5, 2024
5a83526
chore: Boot waterfront
dengwxn Sep 5, 2024
f10d357
Merge branch 'master' of github.com:dengwxn/ray
dengwxn Sep 5, 2024
ef71af5
Merge branch 'master' of https://github.com/ray-project/ray
dengwxn Sep 18, 2024
2a476de
chore: Clean up
dengwxn Sep 18, 2024
6ae63c8
refactor: Merge ccar-0905
dengwxn Sep 24, 2024
76f57c9
refactor: Merge ccar-0905
dengwxn Sep 25, 2024
8292346
Merge branch 'master' into ccar-0905
dengwxn Sep 25, 2024
6d262fa
chore: Fix type check
dengwxn Sep 25, 2024
204581e
refactor: Fix remove in edges
dengwxn Sep 25, 2024
88c7dc6
chore: Revert file
dengwxn Sep 25, 2024
b3e1742
refactor: Fix api
dengwxn Sep 26, 2024
7236505
refactor: Fix api
dengwxn Sep 27, 2024
283869b
feat: Merge upstream
dengwxn Oct 13, 2024
6e59d3d
(WIP) refactor: Inherit ClassMethodNode for CollectiveOutputNode
dengwxn Oct 13, 2024
c94ffa9
Merge branch 'master' into ccar-0905
dengwxn Oct 14, 2024
a045889
refactor: Inherit from ClassMethodNode
dengwxn Oct 14, 2024
c73efd6
test: Change size
dengwxn Oct 15, 2024
7be6954
refactor: Code review
dengwxn Oct 16, 2024
b58d27e
refactor: Unify update candidate nodes
dengwxn Oct 16, 2024
6e81470
fix: Union two sets of nccl group ids
dengwxn Oct 17, 2024
bd7ba01
merge: Upstream master
dengwxn Oct 17, 2024
c013bc7
test: Reduce op values
dengwxn Oct 17, 2024
8e570c9
refactor: Fix reduce op values
dengwxn Oct 17, 2024
406dcd0
fix: API annotations
dengwxn Oct 17, 2024
f90270b
merge: Polish tests
dengwxn Oct 17, 2024
f41904a
chore: Polish tests
dengwxn Oct 17, 2024
789b9bc
test: Check num
dengwxn Oct 17, 2024
7e05f85
test: Remove non-tensor input case
dengwxn Oct 17, 2024
a0f1381
test: Remove allocate tensor case
dengwxn Oct 17, 2024
00949bd
chore: Polish tests
AndyUB Oct 17, 2024
6c80c59
merge: Upstream
AndyUB Oct 17, 2024
6e71c93
refactor: Test separate types
dengwxn Oct 18, 2024
ccea7a3
refactor: Test original types
dengwxn Oct 18, 2024
f634bbb
refactor: Use separate types by if-else
dengwxn Oct 18, 2024
ccbf68b
refactor: Use separate types by if-else
dengwxn Oct 18, 2024
70ea96d
refactor: Convert to ray op
dengwxn Oct 18, 2024
4b77907
revert: Skip ray types
dengwxn Oct 18, 2024
c5e2c20
revert: Skip ray types
dengwxn Oct 18, 2024
5e94216
merge: Upstream
AndyUB Oct 18, 2024
09c6709
chore: Cleanup tests
AndyUB Oct 18, 2024
b8d1891
chore: Code review
dengwxn Oct 19, 2024
b3770a6
chore: Simplify tests
AndyUB Oct 19, 2024
c249ec3
merge: Upstream
AndyUB Oct 19, 2024
ff7f720
chore: Polish tests
AndyUB Oct 19, 2024
d8c85fb
Merge pull request #15 from AndyUB/test-1017
dengwxn Oct 20, 2024
0486914
refactor: Polish tests
dengwxn Oct 20, 2024
6d0db72
chore: Format
dengwxn Oct 20, 2024
2d59839
Merge branch 'master' into ccar-0905
dengwxn Oct 20, 2024
860139c
test: Mock gpus
dengwxn Oct 20, 2024
7df480f
merge: Upstream branch
dengwxn Oct 20, 2024
3 changes: 2 additions & 1 deletion python/ray/dag/BUILD
@@ -105,6 +105,7 @@ py_test_module_list(
         "tests/experimental/test_detect_deadlock_dag.py",
         "tests/experimental/test_multi_node_dag.py",
         "tests/experimental/test_torch_tensor_dag.py",
+        "tests/experimental/test_collective_dag.py",
         "tests/experimental/test_execution_schedule.py",
     ],
     tags = [
@@ -132,7 +133,7 @@ py_test_module_list(
 
 py_test(
     name = "test_torch_tensor_dag_gpu",
-    size = "medium",
+    size = "large",
     srcs = [
         "tests/experimental/test_torch_tensor_dag.py",
     ],
8 changes: 8 additions & 0 deletions python/ray/dag/__init__.py
@@ -4,6 +4,7 @@
     ClassNode,
     ClassMethodNode,
 )
+from ray.dag.collective_node import CollectiveOutputNode
 from ray.dag.input_node import (
     InputNode,
     InputAttributeNode,
@@ -13,6 +14,9 @@
 from ray.dag.constants import (
     PARENT_CLASS_NODE_KEY,
     PREV_CLASS_METHOD_CALL_KEY,
+    BIND_INDEX_KEY,
+    IS_CLASS_METHOD_OUTPUT_KEY,
+    COLLECTIVE_OPERATION_KEY,
     DAGNODE_TYPE_KEY,
 )
 from ray.dag.vis_utils import plot
@@ -21,13 +25,17 @@
 __all__ = [
     "ClassNode",
     "ClassMethodNode",
+    "CollectiveOutputNode",
     "DAGNode",
     "FunctionNode",
     "InputNode",
     "InputAttributeNode",
     "DAGInputData",
     "PARENT_CLASS_NODE_KEY",
     "PREV_CLASS_METHOD_CALL_KEY",
+    "BIND_INDEX_KEY",
+    "IS_CLASS_METHOD_OUTPUT_KEY",
+    "COLLECTIVE_OPERATION_KEY",
     "DAGNODE_TYPE_KEY",
     "plot",
     "MultiOutputNode",
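For reference, a minimal sketch of what this hunk exposes: the new node type and constants can now be imported straight from ray.dag. Only names added to the export list above are used.

# Minimal sketch: imports only symbols exported by this hunk.
# COLLECTIVE_OPERATION_KEY is the other_args_to_resolve key under which
# collective_node.py (next file) looks up its _CollectiveOperation.
from ray.dag import CollectiveOutputNode, COLLECTIVE_OPERATION_KEY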
192 changes: 192 additions & 0 deletions python/ray/dag/collective_node.py
@@ -0,0 +1,192 @@
from typing import Any, Dict, List, Union, Tuple, Optional, TYPE_CHECKING

if TYPE_CHECKING:
import torch

import ray
from ray.dag import (
DAGNode,
ClassMethodNode,
)
from ray.dag.constants import COLLECTIVE_OPERATION_KEY
from ray.experimental.channel import ChannelContext
from ray.experimental.channel.torch_tensor_nccl_channel import _init_nccl_group
from ray.experimental.channel.torch_tensor_type import GPUCommunicator, TorchTensorType
from ray.experimental.util.types import _CollectiveOp, ReduceOp
from ray.util.annotations import DeveloperAPI


class _CollectiveOperation:
"""
Represent metadata for a NCCL collective operation.

Args:
input_nodes: A list of input nodes to the collective operation.
op: The collective operation to perform.
transport: The transport to use for the collective operation.

Requirements:
1. Input nodes are unique.
2. Actor handles are unique.
3. Actor handles match the custom NCCL group if specified.
"""

def __init__(
self,
input_nodes: List[DAGNode],
op: _CollectiveOp,
transport: Optional[Union[str, GPUCommunicator]] = None,
):
self._input_nodes: List[DAGNode] = input_nodes
if len(self._input_nodes) == 0:
raise ValueError("Expected input nodes for a collective operation")
if len(set(self._input_nodes)) != len(self._input_nodes):
raise ValueError("Expected unique input nodes for a collective operation")

self._actor_handles: List["ray.actor.ActorHandle"] = []
for input_node in self._input_nodes:
actor_handle = input_node._get_actor_handle()
if actor_handle is None:
raise ValueError("Expected an actor handle from the input node")
self._actor_handles.append(actor_handle)
if len(set(self._actor_handles)) != len(self._actor_handles):
invalid_input_nodes = [
input_node
for input_node in self._input_nodes
if self._actor_handles.count(input_node._get_actor_handle()) > 1
]
raise ValueError(
"Expected unique actor handles for a collective operation, "
"but found duplicate actor handles from input nodes: "
f"{invalid_input_nodes}"
)

self._op = op
if not isinstance(self._op, ReduceOp):
raise NotImplementedError("Only ReduceOp is implemented")
if transport is None:
transport = TorchTensorType.NCCL
self._type_hint = TorchTensorType(transport=transport, _direct_return=True)
if isinstance(transport, GPUCommunicator):
if set(transport.get_actor_handles()) != set(self._actor_handles):
raise ValueError(
"Expected actor handles to match the custom NCCL group"
)

def __str__(self) -> str:
return (
f"CollectiveGroup("
f"_input_nodes={self._input_nodes}, "
f"_actor_handles={self._actor_handles}, "
f"_op={self._op}, "
f"_type_hint={self._type_hint})"
)

@property
def actor_handles(self) -> List["ray.actor.ActorHandle"]:
return self._actor_handles

@property
def type_hint(self) -> TorchTensorType:
return self._type_hint

def init_nccl_group(self, nccl_group_id: Optional[str] = None) -> str:
"""
Initialize the NCCL group if it has not been initialized yet. If `nccl_group_id`
is provided, it means the NCCL group has already been initialized.
"""
type_hint = self._type_hint
if type_hint.nccl_group_id is not None:
return type_hint.nccl_group_id
if nccl_group_id is None:
nccl_group_id = _init_nccl_group(
self._actor_handles, type_hint.get_custom_nccl_group()
)
type_hint.set_nccl_group_id(nccl_group_id)
return nccl_group_id

def get_nccl_group(self) -> GPUCommunicator:
if self._type_hint.nccl_group_id is not None:
ctx = ChannelContext.get_current()
nccl_group = ctx.nccl_groups[self._type_hint.nccl_group_id]
elif self._type_hint.get_custom_nccl_group() is not None:
nccl_group = self._type_hint.get_custom_nccl_group()
else:
raise ValueError("Expected a NCCL group")
return nccl_group

def execute(self, send_buf: "torch.Tensor") -> "torch.Tensor":
"""
Call the collective operation on the input tensor. An output tensor is
allocated and returned.
"""
import torch

if not isinstance(send_buf, torch.Tensor):
raise ValueError("Expected a torch tensor")
nccl_group = self.get_nccl_group()
recv_buf = torch.empty_like(send_buf)
nccl_group.allreduce(send_buf, recv_buf, self._op)
return recv_buf


@DeveloperAPI
class CollectiveOutputNode(ClassMethodNode):
"""Represent an output node from a NCCL collective operation in a Ray DAG."""

def __init__(
self,
method_name: str,
method_args: Tuple[
DAGNode,
],
method_kwargs: Dict[str, Any],
method_options: Dict[str, Any],
other_args_to_resolve: Dict[str, Any],
):
# Parse the input node.
if not (
isinstance(method_args, tuple)
and len(method_args) == 1
and isinstance(method_args[0], DAGNode)
):
raise ValueError("Expected a single input node")
self._input_node = method_args[0]
# Parse the collective operation.
self._collective_op: _CollectiveOperation = other_args_to_resolve.get(
COLLECTIVE_OPERATION_KEY, None
)
if self._collective_op is None:
raise ValueError("Expected a collective operation")

super().__init__(
method_name,
method_args,
method_kwargs,
method_options,
other_args_to_resolve,
)

def _copy_impl(
self,
new_args: List[Any],
new_kwargs: Dict[str, Any],
new_options: Dict[str, Any],
new_other_args_to_resolve: Dict[str, Any],
):
return CollectiveOutputNode(
self._method_name,
new_args,
new_kwargs,
new_options,
other_args_to_resolve=new_other_args_to_resolve,
)

def _execute_impl(self, *args, **kwargs):
raise NotImplementedError(
"CollectiveOutputNode is only supported with dag.experimental_compile()"
)

@property
def collective_op(self) -> _CollectiveOperation:
return self._collective_op
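To show how the new node is intended to be driven end to end, here is a hedged usage sketch. It relies on dag.experimental_compile() (referenced in _execute_impl above) plus an allreduce binding helper, ray.experimental.collective.allreduce.bind, which is not part of the files shown in this excerpt; treat that call, the Worker actor, its methods, and the two-GPU setup as illustrative assumptions, not the definitive API.

# Hedged sketch (not from this diff): two GPU actors each produce a tensor,
# the assumed allreduce binding routes them through CollectiveOutputNodes that
# share one _CollectiveOperation, and the compiled DAG returns the reduced
# tensors. Assumes a cluster with 2 GPUs.
import ray
import torch
from ray.dag import InputNode, MultiOutputNode
from ray.experimental import collective  # assumed helper module
from ray.experimental.util.types import ReduceOp


@ray.remote(num_gpus=1)
class Worker:
    def compute(self, value: int) -> torch.Tensor:
        # Each rank contributes a GPU tensor; the collective requires torch tensors.
        return torch.ones(4, device="cuda") * value

    def recv(self, tensor: torch.Tensor) -> list:
        return tensor.tolist()


workers = [Worker.remote() for _ in range(2)]

with InputNode() as inp:
    computes = [w.compute.bind(inp) for w in workers]
    # Assumed API: returns one CollectiveOutputNode per input node, all backed
    # by the same _CollectiveOperation (unique nodes, unique actors, ReduceOp).
    reduced = collective.allreduce.bind(computes, ReduceOp.SUM)
    dag = MultiOutputNode([w.recv.bind(t) for w, t in zip(workers, reduced)])

compiled = dag.experimental_compile()
# With input 2 and two ranks, every rank sees the summed tensor [4.0, 4.0, 4.0, 4.0].
print(ray.get(compiled.execute(2)))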