
[aDAG] Support all reduce collective in aDAG #47621

Merged: 96 commits, merged Oct 21, 2024
Changes shown below are from 19 of the 96 commits.

Commits
dd8afbe
(WIP) chore: Add input index in reader
dengwxn Aug 8, 2024
d957669
chore: Clean up
dengwxn Aug 8, 2024
0df4191
(WIP) chore: Update channel allocation for TaskReturnNode
dengwxn Aug 8, 2024
952daa0
(WIP) chore: Add tests
dengwxn Aug 9, 2024
3479587
chore: Return TaskReturnNode only when num_returns > 1
dengwxn Aug 9, 2024
3312712
chore: Add test for two returns three actors
dengwxn Aug 9, 2024
97faeb7
chore: Clean up
dengwxn Aug 9, 2024
197e65e
chore: Adjust input_idxs
dengwxn Aug 9, 2024
a48fecd
(WIP) chore: Allocate an output channel for each output value
dengwxn Aug 16, 2024
16e0323
(WIP) chore: Remove legacy dependencies
dengwxn Aug 16, 2024
d12f28b
(WIP) chore: Remove legacy dependencies
dengwxn Aug 16, 2024
00a6de7
chore: Update async writer
dengwxn Aug 16, 2024
99a42f6
chore: Update comments
dengwxn Aug 16, 2024
40c6bb2
chore: Remove legacy
dengwxn Aug 16, 2024
6ef6e60
chore: Update tests
dengwxn Aug 22, 2024
6772b3a
chore: Revert read_by_multi_output_node
dengwxn Aug 22, 2024
b35e632
chore: Add comment
dengwxn Aug 22, 2024
f2004a3
chore: Rename output task to upstream/downstream task
dengwxn Aug 22, 2024
a65b053
chore: Update tests
dengwxn Aug 22, 2024
db1d796
chore: Update private fields
dengwxn Aug 22, 2024
ce4694e
Merge branch 'master' into master
dengwxn Aug 22, 2024
d2d50f3
feat: Merge ClassMethodOutputNode into ClassMethodNode
dengwxn Aug 23, 2024
e64a3da
Merge branch 'master' into master
dengwxn Aug 23, 2024
007f656
chore: Add tests and unify internal fields
dengwxn Aug 28, 2024
c8e6a5b
chore: Adjust comments and tests
dengwxn Aug 29, 2024
c729e68
chore: Format code
dengwxn Aug 29, 2024
39b1953
chore: Merge master
dengwxn Aug 29, 2024
39a76c5
chore: Fix check of readers for each output
dengwxn Aug 29, 2024
a84ad8b
chore: Format code
dengwxn Aug 29, 2024
ce472ae
chore: Format code
dengwxn Aug 29, 2024
d04e687
chore: Format code
dengwxn Aug 29, 2024
315a192
chore: Fix typo
dengwxn Aug 30, 2024
f050638
chore: Fix typo
dengwxn Aug 30, 2024
8fc7cd0
chore: Fix typo
dengwxn Aug 30, 2024
a88defb
chore: Fix typo
dengwxn Aug 30, 2024
7e3eb68
chore: Merge master
dengwxn Aug 30, 2024
1d08bd7
chore: Format code
dengwxn Aug 30, 2024
28bc144
chore: Fix output channels
dengwxn Aug 30, 2024
81915af
Merge branch 'master' of https://github.com/ray-project/ray
dengwxn Aug 31, 2024
dbb85dc
chore: Fix reader_handles_set
dengwxn Aug 31, 2024
9af6490
Merge branch 'master' of github.com:dengwxn/ray
dengwxn Aug 31, 2024
e463a61
Merge branch 'master' of https://github.com/ray-project/ray
dengwxn Aug 31, 2024
b8173b7
chore: Fix mock class method call in tests
dengwxn Sep 1, 2024
d76a106
Merge branch 'master' into master
dengwxn Sep 2, 2024
39a4452
Merge branch 'ray-project:master' into master
dengwxn Sep 2, 2024
dc2586d
Merge branch 'ray-project:master' into master
dengwxn Sep 5, 2024
5a83526
chore: Boot waterfront
dengwxn Sep 5, 2024
f10d357
Merge branch 'master' of github.com:dengwxn/ray
dengwxn Sep 5, 2024
ef71af5
Merge branch 'master' of https://github.com/ray-project/ray
dengwxn Sep 18, 2024
2a476de
chore: Clean up
dengwxn Sep 18, 2024
6ae63c8
refactor: Merge ccar-0905
dengwxn Sep 24, 2024
76f57c9
refactor: Merge ccar-0905
dengwxn Sep 25, 2024
8292346
Merge branch 'master' into ccar-0905
dengwxn Sep 25, 2024
6d262fa
chore: Fix type check
dengwxn Sep 25, 2024
204581e
refactor: Fix remove in edges
dengwxn Sep 25, 2024
88c7dc6
chore: Revert file
dengwxn Sep 25, 2024
b3e1742
refactor: Fix api
dengwxn Sep 26, 2024
7236505
refactor: Fix api
dengwxn Sep 27, 2024
283869b
feat: Merge upstream
dengwxn Oct 13, 2024
6e59d3d
(WIP) refactor: Inherit ClassMethodNode for CollectiveOutputNode
dengwxn Oct 13, 2024
c94ffa9
Merge branch 'master' into ccar-0905
dengwxn Oct 14, 2024
a045889
refactor: Inherit from ClassMethodNode
dengwxn Oct 14, 2024
c73efd6
test: Change size
dengwxn Oct 15, 2024
7be6954
refactor: Code review
dengwxn Oct 16, 2024
b58d27e
refactor: Unify update candidate nodes
dengwxn Oct 16, 2024
6e81470
fix: Union two sets of nccl group ids
dengwxn Oct 17, 2024
bd7ba01
merge: Upstream master
dengwxn Oct 17, 2024
c013bc7
test: Reduce op values
dengwxn Oct 17, 2024
8e570c9
refactor: Fix reduce op values
dengwxn Oct 17, 2024
406dcd0
fix: API annotations
dengwxn Oct 17, 2024
f90270b
merge: Polish tests
dengwxn Oct 17, 2024
f41904a
chore: Polish tests
dengwxn Oct 17, 2024
789b9bc
test: Check num
dengwxn Oct 17, 2024
7e05f85
test: Remove non-tensor input case
dengwxn Oct 17, 2024
a0f1381
test: Remove allocate tensor case
dengwxn Oct 17, 2024
00949bd
chore: Polish tests
AndyUB Oct 17, 2024
6c80c59
merge: Upstream
AndyUB Oct 17, 2024
6e71c93
refactor: Test separate types
dengwxn Oct 18, 2024
ccea7a3
refactor: Test original types
dengwxn Oct 18, 2024
f634bbb
refactor: Use separate types by if-else
dengwxn Oct 18, 2024
ccbf68b
refactor: Use separate types by if-else
dengwxn Oct 18, 2024
70ea96d
refactor: Convert to ray op
dengwxn Oct 18, 2024
4b77907
revert: Skip ray types
dengwxn Oct 18, 2024
c5e2c20
revert: Skip ray types
dengwxn Oct 18, 2024
5e94216
merge: Upstream
AndyUB Oct 18, 2024
09c6709
chore: Cleanup tests
AndyUB Oct 18, 2024
b8d1891
chore: Code review
dengwxn Oct 19, 2024
b3770a6
chore: Simplify tests
AndyUB Oct 19, 2024
c249ec3
merge: Upstream
AndyUB Oct 19, 2024
ff7f720
chore: Polish tests
AndyUB Oct 19, 2024
d8c85fb
Merge pull request #15 from AndyUB/test-1017
dengwxn Oct 20, 2024
0486914
refactor: Polish tests
dengwxn Oct 20, 2024
6d0db72
chore: Format
dengwxn Oct 20, 2024
2d59839
Merge branch 'master' into ccar-0905
dengwxn Oct 20, 2024
860139c
test: Mock gpus
dengwxn Oct 20, 2024
7df480f
merge: Upstream branch
dengwxn Oct 20, 2024
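
Before the file-by-file diff, a rough usage sketch of what this PR enables: an accelerated DAG in which several actors produce tensors that are combined with a NCCL allreduce. The allreduce.bind entry point and the Worker class below are illustrative assumptions; only CollectiveOutputNode, _CollectiveOperation, and ReduceOp are confirmed by the diff that follows.

# Hypothetical usage sketch; allreduce.bind and the Worker methods are
# illustrative assumptions, not confirmed by this diff.
import ray
import torch
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.collective import allreduce
from ray.experimental.util.types import ReduceOp

@ray.remote(num_gpus=1)
class Worker:
    def compute(self, value: int) -> "torch.Tensor":
        # Each worker produces a tensor on its own GPU.
        return torch.ones(4, device="cuda") * value

    def finish(self, reduced: "torch.Tensor") -> "torch.Tensor":
        # Consume the allreduced tensor.
        return reduced

workers = [Worker.remote() for _ in range(2)]

with InputNode() as inp:
    tensors = [w.compute.bind(inp) for w in workers]
    # Binding an allreduce yields one CollectiveOutputNode per participating actor.
    reduced = allreduce.bind(tensors, op=ReduceOp.SUM)
    dag = MultiOutputNode([w.finish.bind(t) for w, t in zip(workers, reduced)])

compiled_dag = dag.experimental_compile()
result = ray.get(compiled_dag.execute(1))

Each element of reduced feeds back to the actor that produced the corresponding input tensor, which is why the sketch zips workers with the allreduce outputs.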
1 change: 1 addition & 0 deletions python/ray/dag/BUILD
@@ -105,6 +105,7 @@ py_test_module_list(
"tests/experimental/test_detect_deadlock_dag.py",
"tests/experimental/test_multi_node_dag.py",
"tests/experimental/test_torch_tensor_dag.py",
"tests/experimental/test_collective_dag.py",
"tests/experimental/test_execution_schedule.py",
],
tags = [
4 changes: 2 additions & 2 deletions python/ray/dag/__init__.py
@@ -16,7 +16,7 @@
PREV_CLASS_METHOD_CALL_KEY,
BIND_INDEX_KEY,
IS_CLASS_METHOD_OUTPUT_KEY,
COLLECTIVE_GROUP_KEY,
COLLECTIVE_OPERATION_KEY,
DAGNODE_TYPE_KEY,
)
from ray.dag.vis_utils import plot
@@ -35,7 +35,7 @@
"PREV_CLASS_METHOD_CALL_KEY",
"BIND_INDEX_KEY",
"IS_CLASS_METHOD_OUTPUT_KEY",
"COLLECTIVE_GROUP_KEY",
"COLLECTIVE_OPERATION_KEY",
"DAGNODE_TYPE_KEY",
"plot",
"MultiOutputNode",
40 changes: 21 additions & 19 deletions python/ray/dag/collective_node.py
@@ -8,22 +8,27 @@
DAGNode,
ClassMethodNode,
)
from ray.dag.constants import COLLECTIVE_GROUP_KEY
from ray.util.annotations import DeveloperAPI
from ray.dag.constants import COLLECTIVE_OPERATION_KEY
from ray.experimental.channel import ChannelContext
from ray.experimental.channel.torch_tensor_nccl_channel import _init_nccl_group
from ray.experimental.channel.torch_tensor_type import GPUCommunicator, TorchTensorType
from ray.experimental.util.types import _CollectiveOp, ReduceOp
from ray.util.annotations import DeveloperAPI


class _CollectiveGroup:
class _CollectiveOperation:
"""
Represent metadata for a NCCL collective operation.

Args:
input_nodes: A list of input nodes to the collective operation.
op: The collective operation to perform.
transport: The transport to use for the collective operation.

Requirements:
1. Input nodes are unique.
2. Actor handles are unique.
3. Actor handles match the custom NCCL group if specified.
"""

def __init__(
@@ -34,9 +39,9 @@ def __init__(
):
self._input_nodes: List[DAGNode] = input_nodes
if len(self._input_nodes) == 0:
raise ValueError("Expected input nodes for a collective group")
raise ValueError("Expected input nodes for a collective operation")
if len(set(self._input_nodes)) != len(self._input_nodes):
raise ValueError("Expected unique input nodes for a collective group")
raise ValueError("Expected unique input nodes for a collective operation")

self._actor_handles: List["ray.actor.ActorHandle"] = []
for input_node in self._input_nodes:
@@ -51,8 +56,9 @@ def __init__(
if self._actor_handles.count(input_node._get_actor_handle()) > 1
]
raise ValueError(
"Expected unique actor handles for a collective group, but found "
f"duplicate actor handles from input nodes: {invalid_input_nodes}"
"Expected unique actor handles for a collective operation, "
"but found duplicate actor handles from input nodes: "
f"{invalid_input_nodes}"
)

self._op = op
@@ -91,7 +97,6 @@ def init_nccl_group(self, nccl_group_id: Optional[str] = None) -> str:
"""
type_hint = self._type_hint
if type_hint.nccl_group_id is not None:
# The NCCL group has already been initialized.
return type_hint.nccl_group_id
if nccl_group_id is None:
nccl_group_id = _init_nccl_group(
@@ -110,7 +115,7 @@ def get_nccl_group(self) -> GPUCommunicator:
raise ValueError("Expected a NCCL group")
return nccl_group

def method(self, send_buf: "torch.Tensor") -> "torch.Tensor":
def execute(self, send_buf: "torch.Tensor") -> "torch.Tensor":
"""
Call the collective operation on the input tensor. An output tensor is
allocated and returned.
@@ -147,16 +152,13 @@ def __init__(
):
raise ValueError("Expected a single input node")
self._input_node = method_args[0]
# Parse the collective group.
self._collective_group: _CollectiveGroup = other_args_to_resolve.get(
COLLECTIVE_GROUP_KEY, None
# Parse the collective operation.
self._collective_op: _CollectiveOperation = other_args_to_resolve.get(
COLLECTIVE_OPERATION_KEY, None
)
if self._collective_group is None:
raise ValueError("Expected a collective group")
if self._collective_op is None:
raise ValueError("Expected a collective operation")

# The actor creation task dependency is encoded as the first argument,
# and the ordering dependency as the second, which ensures they are
# executed prior to this node.
super().__init__(
method_name,
method_args,
@@ -186,5 +188,5 @@ def _execute_impl(self, *args, **kwargs):
)

@property
def collective_group(self) -> _CollectiveGroup:
return self._collective_group
def collective_op(self) -> _CollectiveOperation:
return self._collective_op
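
The body of execute() is collapsed in the rendered diff above. As a rough mental model only (the allreduce call signature below is an assumption, not the confirmed implementation), the operation validates the send buffer, allocates a matching receive buffer, and runs the collective on the communicator returned by get_nccl_group():

# Simplified, standalone mental model of _CollectiveOperation.execute; the real
# body is collapsed in the diff above, so the allreduce signature is an assumption.
import torch

def execute_allreduce(communicator, send_buf: torch.Tensor, op) -> torch.Tensor:
    if not isinstance(send_buf, torch.Tensor):
        raise ValueError("Expected a torch tensor as the send buffer")
    # Allocate the output tensor with the same shape, dtype, and device, then
    # run the collective on the cached communicator (GPUCommunicator in the diff).
    recv_buf = torch.empty_like(send_buf)
    communicator.allreduce(send_buf, recv_buf, op)
    return recv_buf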
134 changes: 74 additions & 60 deletions python/ray/dag/compiled_dag_node.py
@@ -309,10 +309,10 @@ def __init__(
self.input_type_hints: List[ChannelOutputType] = task.arg_type_hints
self.output_type_hint: ChannelOutputType = task.dag_node.type_hint

# The collective group that runs a NCCL collective operation.
self.collective_group: Optional["ray.dag.CollectiveGroup"] = None
# The NCCL collective operation.
self.collective_op: Optional["ray.dag.CollectiveOperation"] = None
if isinstance(task.dag_node, CollectiveOutputNode):
self.collective_group = task.dag_node.collective_group
self.collective_op = task.dag_node.collective_op

self.input_channels: List[ChannelInterface] = []
self.task_inputs: List[_ExecutableTaskInput] = []
@@ -451,9 +451,9 @@ def _compute(self, class_handle) -> bool:
for task_input in self.task_inputs:
resolved_inputs.append(task_input.resolve(input_data))

if self.collective_group is not None:
if self.collective_op is not None:
# Run a NCCL collective operation.
method = self.collective_group.method
method = self.collective_op.execute
else:
# Run an actor method.
method = getattr(class_handle, self.method_name)
@@ -685,7 +685,7 @@ def __init__(
self._use_default_nccl_group = False
# This is set to the specified custom nccl group
# if there exists a type hint of `transport=nccl_group`.
self._custom_nccl_group: Optional[GPUCommunicator] = None
self._custom_nccl_group_p2p: Optional[GPUCommunicator] = None
# The NCCL group ID for P2P send/recv operations.
self._nccl_group_id_p2p: Optional[str] = None
# All the NCCL group IDs for P2P send/recv and collective operations.
@@ -715,6 +715,14 @@ def _create_proxy_actor() -> "ray.actor.ActorHandle":

self._proxy_actor = _create_proxy_actor()

@property
def nccl_group_id_p2p(self) -> Optional[str]:
return self._nccl_group_id_p2p

@property
def nccl_group_ids(self) -> Set[str]:
return self._nccl_group_ids

def increment_max_finished_execution_index(self) -> None:
"""Increment the max finished execution index. It is used to
figure out the max number of in-flight requests to the DAG
@@ -752,13 +760,13 @@ def _preprocess(self) -> None:
InputNode,
MultiOutputNode,
)
from ray.dag.collective_node import _CollectiveGroup
from ray.dag.collective_node import _CollectiveOperation

self.input_task_idx, self.output_task_idx = None, None
self.actor_task_count.clear()

nccl_actors: Set["ray.actor.ActorHandle"] = set()
nccl_collective_groups: Set[_CollectiveGroup] = set()
nccl_actors_p2p: Set["ray.actor.ActorHandle"] = set()
nccl_collective_ops: Set[_CollectiveOperation] = set()

# Find the input node to the DAG.
for idx, task in self.idx_to_task.items():
Expand Down Expand Up @@ -833,7 +841,7 @@ def _preprocess(self) -> None:

# Collect actors for NCCL P2P methods.
if dag_node.type_hint.requires_nccl():
nccl_actors.add(actor_handle)
nccl_actors_p2p.add(actor_handle)
custom_nccl_group = dag_node.type_hint.get_custom_nccl_group()
mixed_nccl_group_error_message = (
"Accelerated DAGs do not support mixed usage of "
@@ -845,26 +853,26 @@
"make sure only one type of NCCL transport is specified."
)
if custom_nccl_group is None:
if self._custom_nccl_group is not None:
if self._custom_nccl_group_p2p is not None:
raise ValueError(mixed_nccl_group_error_message)
self._use_default_nccl_group = True
else:
if self._use_default_nccl_group:
raise ValueError(mixed_nccl_group_error_message)
if self._custom_nccl_group is not None:
if self._custom_nccl_group != custom_nccl_group:
if self._custom_nccl_group_p2p is not None:
if self._custom_nccl_group_p2p != custom_nccl_group:
raise ValueError(
"Accelerated DAGs currently only support "
"a single custom NCCL group, but multiple "
"have been specified. Check all the "
"TorchTensor(transport=nccl_group) type hints "
"to make sure only one NCCL group is used."
)
self._custom_nccl_group = custom_nccl_group
self._custom_nccl_group_p2p = custom_nccl_group

# Collect collective groups for NCCL collective operations.
# Collect NCCL collective operations.
if isinstance(dag_node, CollectiveOutputNode):
nccl_collective_groups.add(dag_node.collective_group)
nccl_collective_ops.add(dag_node.collective_op)
elif isinstance(dag_node, InputNode):
if dag_node.type_hint.requires_nccl():
raise ValueError(
@@ -935,72 +943,80 @@ def _preprocess(self) -> None:
task.arg_type_hints.append(upstream_task.dag_node.type_hint)

if upstream_task.dag_node.type_hint.requires_nccl():
# Add all readers to the NCCL group.
nccl_actors.add(downstream_actor_handle)
# Add all readers to the NCCL actors of P2P.
nccl_actors_p2p.add(downstream_actor_handle)

nccl_actors = list(nccl_actors)
if None in nccl_actors:
nccl_actors_p2p = list(nccl_actors_p2p)
if None in nccl_actors_p2p:
raise ValueError("Driver cannot participate in the NCCL group.")

# Initialize and cache a NCCL group for each custom NCCL group. All the
# custom NCCL groups are initialized before the default NCCL groups.
custom_nccl_group_to_id: Dict[GPUCommunicator, str] = {}
# Initialize and cache a NCCL group for each set of actors. A set of actors
# can perform P2P send/recv and collective operations. All the custom NCCL
# groups are initialized before the default NCCL groups. If there are
# multiple custom NCCL groups for a set of actors, only one is cached.
# can perform P2P send/recv and collective operations. If there are multiple
# custom NCCL groups for a set of actors, only one is cached.
actors_to_nccl_group_id: Dict[FrozenSet["ray.actor.ActorHandle"], str] = {}
# Initialize a NCCL group for each custom NCCL group.
custom_nccl_group_to_id: Dict[GPUCommunicator, str] = {}

# If a custom NCCL group is specified for P2P actors, initialize and cache
# the NCCL group ID.
if nccl_actors and self._custom_nccl_group:
if nccl_actors_p2p and self._custom_nccl_group_p2p:
if not set(nccl_actors_p2p).issubset(
set(self._custom_nccl_group_p2p.get_actor_handles())
):
raise ValueError(
"Expected P2P actor handles to be a subset of the custom NCCL group"
)
self._nccl_group_id_p2p = _init_nccl_group(
nccl_actors, self._custom_nccl_group
nccl_actors_p2p, self._custom_nccl_group_p2p
)
actors = frozenset(nccl_actors)
custom_nccl_group_to_id[
self._custom_nccl_group_p2p
] = self._nccl_group_id_p2p
actors = frozenset(nccl_actors_p2p)
actors_to_nccl_group_id[actors] = self._nccl_group_id_p2p
custom_nccl_group_to_id[self._custom_nccl_group] = self._nccl_group_id_p2p

# If a custom NCCL group is specified for collective actors, initialize and
# cache the NCCL group ID.
for collective_group in nccl_collective_groups:
type_hint = collective_group.type_hint
for collective_op in nccl_collective_ops:
type_hint = collective_op.type_hint
custom_nccl_group = type_hint.get_custom_nccl_group()
if custom_nccl_group:
nccl_group_id = collective_group.init_nccl_group(
nccl_group_id = collective_op.init_nccl_group(
custom_nccl_group_to_id.get(custom_nccl_group, None)
)
actors = frozenset(collective_group.actor_handles)
custom_nccl_group_to_id[custom_nccl_group] = nccl_group_id
actors = frozenset(collective_op.actor_handles)
if actors not in actors_to_nccl_group_id:
actors_to_nccl_group_id[actors] = nccl_group_id
custom_nccl_group_to_id[custom_nccl_group] = nccl_group_id

# If a NCCL group for P2P actors is not initialized, initialize and cache
# the NCCL group ID.
if nccl_actors and self._nccl_group_id_p2p is None:
actors = frozenset(nccl_actors)
if nccl_actors_p2p and self._nccl_group_id_p2p is None:
actors = frozenset(nccl_actors_p2p)
if actors in actors_to_nccl_group_id:
self._nccl_group_id_p2p = actors_to_nccl_group_id[actors]
else:
self._nccl_group_id_p2p = _init_nccl_group(
nccl_actors, self._custom_nccl_group
nccl_actors_p2p, self._custom_nccl_group_p2p
)
actors_to_nccl_group_id[actors] = self._nccl_group_id_p2p

# If a NCCL group for collective actors is not initialized, initialize and
# cache the NCCL group ID.
for collective_group in nccl_collective_groups:
type_hint = collective_group.type_hint
if type_hint.nccl_group_id is None:
actors = frozenset(collective_group.actor_handles)
if actors in actors_to_nccl_group_id:
nccl_group_id = actors_to_nccl_group_id[actors]
type_hint.set_nccl_group_id(nccl_group_id)
else:
nccl_group_id = collective_group.init_nccl_group()
for collective_op in nccl_collective_ops:
if collective_op.type_hint.nccl_group_id is None:
actors = frozenset(collective_op.actor_handles)
nccl_group_id = collective_op.init_nccl_group(
actors_to_nccl_group_id.get(actors, None)
)
if actors not in actors_to_nccl_group_id:
actors_to_nccl_group_id[actors] = nccl_group_id

# Store all the NCCL group IDs for P2P send/recv and collective operations.
self._nccl_group_ids = set(actors_to_nccl_group_id.values())
self._nccl_group_ids = set(actors_to_nccl_group_id.values()).union(
set(custom_nccl_group_to_id.values())
)

if direct_input:
self._input_num_positional_args = 1
@@ -1460,19 +1476,19 @@ def _generate_dag_operation_graph_node(
]
}
"""
from ray.dag.collective_node import CollectiveOutputNode, _CollectiveGroup
from ray.dag.collective_node import CollectiveOutputNode, _CollectiveOperation

assert self.idx_to_task
assert self.actor_to_executable_tasks

actor_to_operation_nodes: Dict[
"ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]]
] = defaultdict(list)
collective_group_to_nodes: Dict[
_CollectiveGroup, Set[_DAGOperationGraphNode]
collective_op_to_nodes: Dict[
_CollectiveOperation, Set[_DAGOperationGraphNode]
] = defaultdict(set)
collective_group_to_idxs: Dict[
_CollectiveGroup, Tuple[int, _DAGNodeOperationType]
collective_op_to_idxs: Dict[
_CollectiveOperation, Tuple[int, _DAGNodeOperationType]
] = defaultdict(set)

for actor_handle, executable_tasks in self.actor_to_executable_tasks.items():
Expand Down Expand Up @@ -1507,18 +1523,16 @@ def _generate_dag_operation_graph_node(
[read_node, compute_node, write_node]
)
if isinstance(dag_node, CollectiveOutputNode):
collective_group_to_nodes[dag_node.collective_group].add(
compute_node
)
collective_group_to_idxs[dag_node.collective_group].add(
collective_op_to_nodes[dag_node.collective_op].add(compute_node)
collective_op_to_idxs[dag_node.collective_op].add(
(task_idx, _DAGNodeOperationType.COMPUTE)
)

# Set collective group nodes for all the NCCL collective nodes.
for collective_group, nodes in collective_group_to_nodes.items():
idxs = collective_group_to_idxs[collective_group]
# Set collective nodes for all the NCCL collective operation nodes.
for collective_op, nodes in collective_op_to_nodes.items():
idxs = collective_op_to_idxs[collective_op]
for node in nodes:
node.set_collective_group_idxs(idxs)
node.collective_idxs = idxs

return actor_to_operation_nodes
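
Much of the _preprocess change above is NCCL-group caching: groups are keyed by the frozenset of participating actor handles so that P2P transfers and collective operations over the same set of actors reuse a single group, and custom groups are initialized before default ones. A minimal standalone sketch of the keying idea, assuming a hypothetical _init_group helper in place of _init_nccl_group:

# Standalone sketch of the NCCL-group caching idea used in _preprocess;
# _init_group is a hypothetical stand-in for _init_nccl_group.
from typing import Dict, FrozenSet, List
import uuid

def _init_group(actors: List[str]) -> str:
    # Pretend to create a NCCL group and return its id.
    return uuid.uuid4().hex

actors_to_group_id: Dict[FrozenSet[str], str] = {}

def get_or_init_group(actors: List[str]) -> str:
    key = frozenset(actors)
    # Reuse the cached group whenever the same actor set appears again, so
    # P2P transfers and collectives over identical actors share one NCCL group.
    if key not in actors_to_group_id:
        actors_to_group_id[key] = _init_group(actors)
    return actors_to_group_id[key]

p2p_group = get_or_init_group(["actor_a", "actor_b"])
collective_group = get_or_init_group(["actor_b", "actor_a"])
assert p2p_group == collective_group  # same actor set -> one shared group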
