diff --git a/python/ray/dag/BUILD b/python/ray/dag/BUILD index 7431e6ce4fad..aef05a3d76fc 100644 --- a/python/ray/dag/BUILD +++ b/python/ray/dag/BUILD @@ -105,6 +105,7 @@ py_test_module_list( "tests/experimental/test_detect_deadlock_dag.py", "tests/experimental/test_multi_node_dag.py", "tests/experimental/test_torch_tensor_dag.py", + "tests/experimental/test_collective_dag.py", "tests/experimental/test_execution_schedule.py", ], tags = [ @@ -132,7 +133,7 @@ py_test_module_list( py_test( name = "test_torch_tensor_dag_gpu", - size = "medium", + size = "large", srcs = [ "tests/experimental/test_torch_tensor_dag.py", ], diff --git a/python/ray/dag/__init__.py b/python/ray/dag/__init__.py index ee5ef07ba821..eb13abc5a53e 100644 --- a/python/ray/dag/__init__.py +++ b/python/ray/dag/__init__.py @@ -4,6 +4,7 @@ ClassNode, ClassMethodNode, ) +from ray.dag.collective_node import CollectiveOutputNode from ray.dag.input_node import ( InputNode, InputAttributeNode, @@ -13,6 +14,9 @@ from ray.dag.constants import ( PARENT_CLASS_NODE_KEY, PREV_CLASS_METHOD_CALL_KEY, + BIND_INDEX_KEY, + IS_CLASS_METHOD_OUTPUT_KEY, + COLLECTIVE_OPERATION_KEY, DAGNODE_TYPE_KEY, ) from ray.dag.vis_utils import plot @@ -21,6 +25,7 @@ __all__ = [ "ClassNode", "ClassMethodNode", + "CollectiveOutputNode", "DAGNode", "FunctionNode", "InputNode", @@ -28,6 +33,9 @@ "DAGInputData", "PARENT_CLASS_NODE_KEY", "PREV_CLASS_METHOD_CALL_KEY", + "BIND_INDEX_KEY", + "IS_CLASS_METHOD_OUTPUT_KEY", + "COLLECTIVE_OPERATION_KEY", "DAGNODE_TYPE_KEY", "plot", "MultiOutputNode", diff --git a/python/ray/dag/collective_node.py b/python/ray/dag/collective_node.py new file mode 100644 index 000000000000..37c7087e51d1 --- /dev/null +++ b/python/ray/dag/collective_node.py @@ -0,0 +1,192 @@ +from typing import Any, Dict, List, Union, Tuple, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + import torch + +import ray +from ray.dag import ( + DAGNode, + ClassMethodNode, +) +from ray.dag.constants import COLLECTIVE_OPERATION_KEY +from ray.experimental.channel import ChannelContext +from ray.experimental.channel.torch_tensor_nccl_channel import _init_nccl_group +from ray.experimental.channel.torch_tensor_type import GPUCommunicator, TorchTensorType +from ray.experimental.util.types import _CollectiveOp, ReduceOp +from ray.util.annotations import DeveloperAPI + + +class _CollectiveOperation: + """ + Represent metadata for a NCCL collective operation. + + Args: + input_nodes: A list of input nodes to the collective operation. + op: The collective operation to perform. + transport: The transport to use for the collective operation. + + Requirements: + 1. Input nodes are unique. + 2. Actor handles are unique. + 3. Actor handles match the custom NCCL group if specified. 
+ """ + + def __init__( + self, + input_nodes: List[DAGNode], + op: _CollectiveOp, + transport: Optional[Union[str, GPUCommunicator]] = None, + ): + self._input_nodes: List[DAGNode] = input_nodes + if len(self._input_nodes) == 0: + raise ValueError("Expected input nodes for a collective operation") + if len(set(self._input_nodes)) != len(self._input_nodes): + raise ValueError("Expected unique input nodes for a collective operation") + + self._actor_handles: List["ray.actor.ActorHandle"] = [] + for input_node in self._input_nodes: + actor_handle = input_node._get_actor_handle() + if actor_handle is None: + raise ValueError("Expected an actor handle from the input node") + self._actor_handles.append(actor_handle) + if len(set(self._actor_handles)) != len(self._actor_handles): + invalid_input_nodes = [ + input_node + for input_node in self._input_nodes + if self._actor_handles.count(input_node._get_actor_handle()) > 1 + ] + raise ValueError( + "Expected unique actor handles for a collective operation, " + "but found duplicate actor handles from input nodes: " + f"{invalid_input_nodes}" + ) + + self._op = op + if not isinstance(self._op, ReduceOp): + raise NotImplementedError("Only ReduceOp is implemented") + if transport is None: + transport = TorchTensorType.NCCL + self._type_hint = TorchTensorType(transport=transport, _direct_return=True) + if isinstance(transport, GPUCommunicator): + if set(transport.get_actor_handles()) != set(self._actor_handles): + raise ValueError( + "Expected actor handles to match the custom NCCL group" + ) + + def __str__(self) -> str: + return ( + f"CollectiveGroup(" + f"_input_nodes={self._input_nodes}, " + f"_actor_handles={self._actor_handles}, " + f"_op={self._op}, " + f"_type_hint={self._type_hint})" + ) + + @property + def actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + @property + def type_hint(self) -> TorchTensorType: + return self._type_hint + + def init_nccl_group(self, nccl_group_id: Optional[str] = None) -> str: + """ + Initialize the NCCL group if it has not been initialized yet. If `nccl_group_id` + is provided, it means the NCCL group has already been initialized. + """ + type_hint = self._type_hint + if type_hint.nccl_group_id is not None: + return type_hint.nccl_group_id + if nccl_group_id is None: + nccl_group_id = _init_nccl_group( + self._actor_handles, type_hint.get_custom_nccl_group() + ) + type_hint.set_nccl_group_id(nccl_group_id) + return nccl_group_id + + def get_nccl_group(self) -> GPUCommunicator: + if self._type_hint.nccl_group_id is not None: + ctx = ChannelContext.get_current() + nccl_group = ctx.nccl_groups[self._type_hint.nccl_group_id] + elif self._type_hint.get_custom_nccl_group() is not None: + nccl_group = self._type_hint.get_custom_nccl_group() + else: + raise ValueError("Expected a NCCL group") + return nccl_group + + def execute(self, send_buf: "torch.Tensor") -> "torch.Tensor": + """ + Call the collective operation on the input tensor. An output tensor is + allocated and returned. 
+ """ + import torch + + if not isinstance(send_buf, torch.Tensor): + raise ValueError("Expected a torch tensor") + nccl_group = self.get_nccl_group() + recv_buf = torch.empty_like(send_buf) + nccl_group.allreduce(send_buf, recv_buf, self._op) + return recv_buf + + +@DeveloperAPI +class CollectiveOutputNode(ClassMethodNode): + """Represent an output node from a NCCL collective operation in a Ray DAG.""" + + def __init__( + self, + method_name: str, + method_args: Tuple[ + DAGNode, + ], + method_kwargs: Dict[str, Any], + method_options: Dict[str, Any], + other_args_to_resolve: Dict[str, Any], + ): + # Parse the input node. + if not ( + isinstance(method_args, tuple) + and len(method_args) == 1 + and isinstance(method_args[0], DAGNode) + ): + raise ValueError("Expected a single input node") + self._input_node = method_args[0] + # Parse the collective operation. + self._collective_op: _CollectiveOperation = other_args_to_resolve.get( + COLLECTIVE_OPERATION_KEY, None + ) + if self._collective_op is None: + raise ValueError("Expected a collective operation") + + super().__init__( + method_name, + method_args, + method_kwargs, + method_options, + other_args_to_resolve, + ) + + def _copy_impl( + self, + new_args: List[Any], + new_kwargs: Dict[str, Any], + new_options: Dict[str, Any], + new_other_args_to_resolve: Dict[str, Any], + ): + return CollectiveOutputNode( + self._method_name, + new_args, + new_kwargs, + new_options, + other_args_to_resolve=new_other_args_to_resolve, + ) + + def _execute_impl(self, *args, **kwargs): + raise NotImplementedError( + "CollectiveOutputNode is only supported with dag.experimental_compile()" + ) + + @property + def collective_op(self) -> _CollectiveOperation: + return self._collective_op diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 0e56c1f4ea52..ba6e6e4a6c6a 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -1,7 +1,7 @@ import asyncio from collections import defaultdict from dataclasses import dataclass, asdict -from typing import Any, Dict, List, Tuple, Union, Optional, Set +from typing import Any, Dict, FrozenSet, List, Tuple, Union, Optional, Set import logging import threading import time @@ -300,12 +300,19 @@ def __init__( do not support binding kwargs to other DAG nodes, so the values of the dictionary cannot be Channels. """ + from ray.dag import CollectiveOutputNode + self.method_name = task.dag_node.get_method_name() self.bind_index = task.dag_node._get_bind_index() self.output_channels = task.output_channels self.output_idxs = task.output_idxs - self.input_type_hints: List["ChannelOutputType"] = task.arg_type_hints - self.output_type_hint: "ChannelOutputType" = task.dag_node.type_hint + self.input_type_hints: List[ChannelOutputType] = task.arg_type_hints + self.output_type_hint: ChannelOutputType = task.dag_node.type_hint + + # The NCCL collective operation. + self.collective_op: Optional["ray.dag.CollectiveOperation"] = None + if isinstance(task.dag_node, CollectiveOutputNode): + self.collective_op = task.dag_node.collective_op self.input_channels: List[ChannelInterface] = [] self.task_inputs: List[_ExecutableTaskInput] = [] @@ -430,7 +437,6 @@ def _compute(self, class_handle) -> bool: True if system error occurs and exit the loop; otherwise, False. 
""" input_data = self.reset_intermediate_buffer() - method = getattr(class_handle, self.method_name) try: _process_return_vals(input_data, return_single_output=False) except Exception as exc: @@ -445,6 +451,12 @@ def _compute(self, class_handle) -> bool: for task_input in self.task_inputs: resolved_inputs.append(task_input.resolve(input_data)) + if self.collective_op is not None: + # Run a NCCL collective operation. + method = self.collective_op.execute + else: + # Run an actor method. + method = getattr(class_handle, self.method_name) try: output_val = method(*resolved_inputs, **self.resolved_kwargs) except Exception as exc: @@ -669,14 +681,15 @@ def __init__( # Mapping from the actor handle to the node ID that the actor is on. self.actor_to_node_id: Dict["ray.actor.ActorHandle", str] = {} - # This is set to true when type hint of `transport="nccl"`` is used + # This is set to true when type hint of `transport="nccl"` is used. self._use_default_nccl_group = False # This is set to the specified custom nccl group - # if there exists a type hint of `transport=nccl_group` - self._custom_nccl_group: Optional[GPUCommunicator] = None - # Uniquely identifies the NCCL communicator that will be used within - # this DAG, if any. - self._nccl_group_id: Optional[str] = None + # if there exists a type hint of `transport=nccl_group`. + self._custom_nccl_group_p2p: Optional[GPUCommunicator] = None + # The NCCL group ID for P2P send/recv operations. + self._nccl_group_id_p2p: Optional[str] = None + # All the NCCL group IDs for P2P send/recv and collective operations. + self._nccl_group_ids: Set[str] = set() # The index of the current execution. It is incremented each time # the DAG is executed. self._execution_index: int = 0 @@ -702,6 +715,14 @@ def _create_proxy_actor() -> "ray.actor.ActorHandle": self._proxy_actor = _create_proxy_actor() + @property + def nccl_group_id_p2p(self) -> Optional[str]: + return self._nccl_group_id_p2p + + @property + def nccl_group_ids(self) -> Set[str]: + return self._nccl_group_ids + def increment_max_finished_execution_index(self) -> None: """Increment the max finished execution index. It is used to figure out the max number of in-flight requests to the DAG @@ -733,16 +754,19 @@ def _preprocess(self) -> None: from ray.dag import ( DAGNode, ClassMethodNode, + CollectiveOutputNode, FunctionNode, InputAttributeNode, InputNode, MultiOutputNode, ) + from ray.dag.collective_node import _CollectiveOperation self.input_task_idx, self.output_task_idx = None, None self.actor_task_count.clear() - nccl_actors: Set["ray.actor.ActorHandle"] = set() + nccl_actors_p2p: Set["ray.actor.ActorHandle"] = set() + nccl_collective_ops: Set[_CollectiveOperation] = set() # Find the input node to the DAG. for idx, task in self.idx_to_task.items(): @@ -815,9 +839,9 @@ def _preprocess(self) -> None: self.actor_task_count[actor_handle._actor_id] += 1 + # Collect actors for NCCL P2P methods. if dag_node.type_hint.requires_nccl(): - # Add all writers to the NCCL group. - nccl_actors.add(actor_handle) + nccl_actors_p2p.add(actor_handle) custom_nccl_group = dag_node.type_hint.get_custom_nccl_group() mixed_nccl_group_error_message = ( "Accelerated DAGs do not support mixed usage of " @@ -829,14 +853,14 @@ def _preprocess(self) -> None: "make sure only one type of NCCL transport is specified." 
) if custom_nccl_group is None: - if self._custom_nccl_group is not None: + if self._custom_nccl_group_p2p is not None: raise ValueError(mixed_nccl_group_error_message) self._use_default_nccl_group = True else: if self._use_default_nccl_group: raise ValueError(mixed_nccl_group_error_message) - if self._custom_nccl_group is not None: - if self._custom_nccl_group != custom_nccl_group: + if self._custom_nccl_group_p2p is not None: + if self._custom_nccl_group_p2p != custom_nccl_group: raise ValueError( "Accelerated DAGs currently only support " "a single custom NCCL group, but multiple " @@ -844,7 +868,11 @@ def _preprocess(self) -> None: "TorchTensor(transport=nccl_group) type hints " "to make sure only one NCCL group is used." ) - self._custom_nccl_group = custom_nccl_group + self._custom_nccl_group_p2p = custom_nccl_group + + # Collect NCCL collective operations. + if isinstance(dag_node, CollectiveOutputNode): + nccl_collective_ops.add(dag_node.collective_op) elif isinstance(dag_node, InputNode): if dag_node.type_hint.requires_nccl(): raise ValueError( @@ -915,16 +943,80 @@ def _preprocess(self) -> None: task.arg_type_hints.append(upstream_task.dag_node.type_hint) if upstream_task.dag_node.type_hint.requires_nccl(): - # Add all readers to the NCCL group. - nccl_actors.add(downstream_actor_handle) + # Add all readers to the NCCL actors of P2P. + nccl_actors_p2p.add(downstream_actor_handle) - # If there were type hints indicating transport via NCCL, initialize - # the NCCL group on the participating actors. - nccl_actors = list(nccl_actors) - if None in nccl_actors: + nccl_actors_p2p = list(nccl_actors_p2p) + if None in nccl_actors_p2p: raise ValueError("Driver cannot participate in the NCCL group.") - if nccl_actors and self._nccl_group_id is None: - self._nccl_group_id = _init_nccl_group(nccl_actors, self._custom_nccl_group) + + # Initialize and cache a NCCL group for each custom NCCL group. All the + # custom NCCL groups are initialized before the default NCCL groups. + custom_nccl_group_to_id: Dict[GPUCommunicator, str] = {} + # Initialize and cache a NCCL group for each set of actors. A set of actors + # can perform P2P send/recv and collective operations. If there are multiple + # custom NCCL groups for a set of actors, only one is cached. + actors_to_nccl_group_id: Dict[FrozenSet["ray.actor.ActorHandle"], str] = {} + + # If a custom NCCL group is specified for P2P actors, initialize and cache + # the NCCL group ID. + if nccl_actors_p2p and self._custom_nccl_group_p2p: + if not set(nccl_actors_p2p).issubset( + set(self._custom_nccl_group_p2p.get_actor_handles()) + ): + raise ValueError( + "Expected P2P actor handles to be a subset of the custom NCCL group" + ) + self._nccl_group_id_p2p = _init_nccl_group( + nccl_actors_p2p, self._custom_nccl_group_p2p + ) + custom_nccl_group_to_id[ + self._custom_nccl_group_p2p + ] = self._nccl_group_id_p2p + actors = frozenset(nccl_actors_p2p) + actors_to_nccl_group_id[actors] = self._nccl_group_id_p2p + + # If a custom NCCL group is specified for collective actors, initialize and + # cache the NCCL group ID. 
+ for collective_op in nccl_collective_ops: + type_hint = collective_op.type_hint + custom_nccl_group = type_hint.get_custom_nccl_group() + if custom_nccl_group: + nccl_group_id = collective_op.init_nccl_group( + custom_nccl_group_to_id.get(custom_nccl_group, None) + ) + custom_nccl_group_to_id[custom_nccl_group] = nccl_group_id + actors = frozenset(collective_op.actor_handles) + if actors not in actors_to_nccl_group_id: + actors_to_nccl_group_id[actors] = nccl_group_id + + # If a NCCL group for P2P actors is not initialized, initialize and cache + # the NCCL group ID. + if nccl_actors_p2p and self._nccl_group_id_p2p is None: + actors = frozenset(nccl_actors_p2p) + if actors in actors_to_nccl_group_id: + self._nccl_group_id_p2p = actors_to_nccl_group_id[actors] + else: + self._nccl_group_id_p2p = _init_nccl_group( + nccl_actors_p2p, self._custom_nccl_group_p2p + ) + actors_to_nccl_group_id[actors] = self._nccl_group_id_p2p + + # If a NCCL group for collective actors is not initialized, initialize and + # cache the NCCL group ID. + for collective_op in nccl_collective_ops: + if collective_op.type_hint.nccl_group_id is None: + actors = frozenset(collective_op.actor_handles) + nccl_group_id = collective_op.init_nccl_group( + actors_to_nccl_group_id.get(actors, None) + ) + if actors not in actors_to_nccl_group_id: + actors_to_nccl_group_id[actors] = nccl_group_id + + # Store all the NCCL group IDs for P2P send/recv and collective operations. + self._nccl_group_ids = set(actors_to_nccl_group_id.values()).union( + set(custom_nccl_group_to_id.values()) + ) if direct_input: self._input_num_positional_args = 1 @@ -999,7 +1091,7 @@ def _get_or_compile( task = self.idx_to_task[cur_idx] type_hint = task.dag_node.type_hint if type_hint.requires_nccl(): - type_hint.set_nccl_group_id(self._nccl_group_id) + type_hint.set_nccl_group_id(self._nccl_group_id_p2p) if ( isinstance(task.dag_node, ClassMethodNode) @@ -1384,43 +1476,64 @@ def _generate_dag_operation_graph_node( ] } """ + from ray.dag.collective_node import CollectiveOutputNode, _CollectiveOperation + assert self.idx_to_task assert self.actor_to_executable_tasks actor_to_operation_nodes: Dict[ "ray.actor.ActorHandle", List[List[_DAGOperationGraphNode]] ] = defaultdict(list) + collective_op_to_nodes: Dict[ + _CollectiveOperation, Set[_DAGOperationGraphNode] + ] = defaultdict(set) + collective_op_to_idxs: Dict[ + _CollectiveOperation, Tuple[int, _DAGNodeOperationType] + ] = defaultdict(set) for actor_handle, executable_tasks in self.actor_to_executable_tasks.items(): for exec_task_idx, exec_task in enumerate(executable_tasks): # Divide a DAG node into three _DAGOperationGraphNodes: READ, COMPUTE, # and WRITE. Each _DAGOperationGraphNode has a _DAGNodeOperation. 
- task_index = exec_task.task_idx - dag_node = self.idx_to_task[task_index].dag_node + task_idx = exec_task.task_idx + dag_node = self.idx_to_task[task_idx].dag_node actor_handle = dag_node._get_actor_handle() requires_nccl = dag_node.type_hint.requires_nccl() read_node = _DAGOperationGraphNode( _DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.READ), - task_index, + task_idx, actor_handle, requires_nccl, ) compute_node = _DAGOperationGraphNode( _DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.COMPUTE), - task_index, + task_idx, actor_handle, - requires_nccl, + isinstance(dag_node, CollectiveOutputNode), ) write_node = _DAGOperationGraphNode( _DAGNodeOperation(exec_task_idx, _DAGNodeOperationType.WRITE), - task_index, + task_idx, actor_handle, requires_nccl, ) + actor_to_operation_nodes[actor_handle].append( [read_node, compute_node, write_node] ) + if isinstance(dag_node, CollectiveOutputNode): + collective_op_to_nodes[dag_node.collective_op].add(compute_node) + collective_op_to_idxs[dag_node.collective_op].add( + (task_idx, _DAGNodeOperationType.COMPUTE) + ) + + # Set collective nodes for all the NCCL collective operation nodes. + for collective_op, nodes in collective_op_to_nodes.items(): + idxs = collective_op_to_idxs[collective_op] + for node in nodes: + node.collective_idxs = idxs + return actor_to_operation_nodes def _build_execution_schedule( @@ -1551,7 +1664,7 @@ def wait_teardown(self): try: ray.get(ref, timeout=10) except ray.exceptions.GetTimeoutError: - logger.warn( + logger.warning( f"Compiled DAG actor {actor} is still running 10s " "after teardown(). Teardown may hang." ) @@ -1603,8 +1716,8 @@ def teardown(self, wait: bool): logger.exception("Error cancelling worker task") pass - if outer._nccl_group_id is not None: - _destroy_nccl_group(outer._nccl_group_id) + for nccl_group_id in outer._nccl_group_ids: + _destroy_nccl_group(nccl_group_id) if wait: logger.info("Waiting for worker tasks to exit") diff --git a/python/ray/dag/constants.py b/python/ray/dag/constants.py index f1077c7b5104..ed86adf914ed 100644 --- a/python/ray/dag/constants.py +++ b/python/ray/dag/constants.py @@ -6,6 +6,9 @@ BIND_INDEX_KEY = "bind_index" IS_CLASS_METHOD_OUTPUT_KEY = "is_class_method_output" +# Reserved keys used to handle CollectiveOutputNode in Ray DAG building. +COLLECTIVE_OPERATION_KEY = "collective_operation" + # Reserved key to distinguish DAGNode type and avoid collision with user dict. DAGNODE_TYPE_KEY = "__dag_node_type__" diff --git a/python/ray/dag/dag_node.py b/python/ray/dag/dag_node.py index 031709382eda..4305186e0fc3 100644 --- a/python/ray/dag/dag_node.py +++ b/python/ray/dag/dag_node.py @@ -71,7 +71,7 @@ def __init__( # Cached values from last call to execute() self.cache_from_last_execute = {} - self._type_hint: Optional[ChannelOutputType] = ChannelOutputType() + self._type_hint: ChannelOutputType = ChannelOutputType() # Whether this node calls `experimental_compile`. 
self.is_adag_output_node = False @@ -112,7 +112,7 @@ def with_type_hint(self, typ: ChannelOutputType): return self @property - def type_hint(self) -> Optional[ChannelOutputType]: + def type_hint(self) -> ChannelOutputType: return self._type_hint def get_args(self) -> Tuple[Any]: diff --git a/python/ray/dag/dag_node_operation.py b/python/ray/dag/dag_node_operation.py index dcd55ae2c702..8f03547702be 100644 --- a/python/ray/dag/dag_node_operation.py +++ b/python/ray/dag/dag_node_operation.py @@ -1,6 +1,6 @@ from functools import total_ordering from enum import Enum -from typing import Set, Tuple, List, Dict +from typing import Set, Tuple, List, Dict, Optional import ray import heapq from collections import defaultdict @@ -37,7 +37,11 @@ def __init__( self.type = operation_type def __repr__(self): - return f"(Task idx: {self.exec_task_idx}, Type: {self.type})" + return ( + f"_DAGNodeOperation(" + f"exec_task_idx: {self.exec_task_idx}, " + f" type: {self.type})" + ) @total_ordering @@ -72,10 +76,23 @@ def __init__( # be READ, COMPUTE, or WRITE. self.in_edges: Set[Tuple[int, _DAGNodeOperationType]] = set() self.out_edges: Set[Tuple[int, _DAGNodeOperationType]] = set() + # The collective nodes are the nodes that belong to the same collective + # operation. Each node is represented by a tuple of its task idx and type. + self.collective_idxs: Set[Tuple[int, _DAGNodeOperationType]] = set() + # The ready collective nodes are the nodes that are ready to be executed, + # i.e., their in-degrees are zero. When a collective node is ready, it + # will be added to the ready collective nodes of all the nodes in its + # collective operation. + self.ready_collective_idxs: Set[Tuple[int, _DAGNodeOperationType]] = set() - @property - def in_degree(self) -> int: - return len(self.in_edges) + def __repr__(self): + return ( + f"_DAGOperationGraphNode(" + f"operation: {self.operation}, " + f"task_idx: {self.task_idx}, " + f"actor_handle: {self.actor_handle}, " + f"requires_nccl: {self.requires_nccl})" + ) def __lt__(self, other: "_DAGOperationGraphNode"): """ @@ -83,27 +100,28 @@ def __lt__(self, other: "_DAGOperationGraphNode"): `_select_next_nodes`. The priority queue is a min-heap, so the node with higher priority is considered "less than" the other node. """ - # If two nodes belong to the same actor, select the one with - # the smaller `exec_task_idx`. + + def compare(lhs: "_DAGOperationGraphNode", rhs: "_DAGOperationGraphNode"): + # If both nodes belong to the same actor, the node with the smaller + # `exec_task_idx` is prioritized. If two nodes belong to different + # actors, it approximates balancing the scheduled tasks across actors, + # by prioritizing the node with the smaller `exec_task_idx`. The tie + # is broken by the `task_idx`. + if lhs.operation.exec_task_idx != rhs.operation.exec_task_idx: + return lhs.operation.exec_task_idx < rhs.operation.exec_task_idx + return lhs.task_idx < rhs.task_idx + if self.actor_handle == other.actor_handle: - return self.operation.exec_task_idx < other.operation.exec_task_idx - # If two nodes belong to different actors and one of them is an NCCL - # write node, select the one that is not an NCCL write node. 
- is_nccl_write = ( - self.operation.type == _DAGNodeOperationType.WRITE and self.requires_nccl - ) - other_is_nccl_write = ( - other.operation.type == _DAGNodeOperationType.WRITE and other.requires_nccl - ) - if is_nccl_write != other_is_nccl_write: - return not is_nccl_write - # If two nodes belong to different actors and both are either NCCL write - # nodes or neither are NCCL write nodes, select the one with the smaller - # `exec_task_idx`. If they have the same `exec_task_idx`, select the one - # with the smaller `task_idx`. - if self.operation.exec_task_idx != other.operation.exec_task_idx: - return self.operation.exec_task_idx < other.operation.exec_task_idx - return self.task_idx < other.task_idx + # When both nodes belong to the same actor, use the default comparison. + return compare(self, other) + elif self.is_nccl_op != other.is_nccl_op: + # When one node is a NCCL operation and the other is not, prioritize + # the non-NCCL operation. + return not self.is_nccl_op + else: + # When either both nodes are NCCL operations or both nodes are not + # NCCL operations, use the default comparison. + return compare(self, other) def __eq__(self, other: "_DAGOperationGraphNode"): """ @@ -122,6 +140,45 @@ def __hash__(self): """ return hash((self.operation, self.task_idx)) + @property + def in_degree(self) -> int: + return len(self.in_edges) + + @property + def is_ready(self) -> bool: + """ + If a node is not a NCCL collective, it is ready when it has a zero + in-degree. If it is a NCCL collective, it is ready when all the nodes + in its collective operation have zero in-degrees. + """ + return self.in_degree == 0 and ( + len(self.ready_collective_idxs) == len(self.collective_idxs) + ) + + @property + def is_read(self) -> bool: + return self.operation.type == _DAGNodeOperationType.READ + + @property + def is_nccl_collective(self) -> bool: + """ + A node is a NCCL collective if it is a compute node and requires NCCL. + """ + return ( + self.operation.type == _DAGNodeOperationType.COMPUTE and self.requires_nccl + ) + + @property + def is_nccl_write(self) -> bool: + """ + A node is a NCCL write if it is a write node and requires NCCL. + """ + return self.operation.type == _DAGNodeOperationType.WRITE and self.requires_nccl + + @property + def is_nccl_op(self) -> bool: + return self.is_nccl_collective or self.is_nccl_write + def _add_edge(from_node: _DAGOperationGraphNode, to_node: _DAGOperationGraphNode): """ @@ -132,35 +189,53 @@ def _add_edge(from_node: _DAGOperationGraphNode, to_node: _DAGOperationGraphNode to_node.in_edges.add((from_node.task_idx, from_node.operation.type)) -def _select_next_nodes( +def _push_candidate_node_if_ready( actor_to_candidates: Dict["ray._raylet.ActorID", List[_DAGOperationGraphNode]], graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], -): - """ - This function selects the next nodes for topological sort to generate execution - schedule. If there are multiple candidate _DAGOperationGraphNodes, select the node - with the top priority based on the following rules: - - #1 If two candidate nodes belong to the same actor, select the one with - the smaller `exec_task_idx`. + node: _DAGOperationGraphNode, +) -> None: + # Collective operations are ready when all the collective nodes have zero + # in-degrees. Only one node per collective will be added as ready. 
+ if node.is_nccl_collective: + for collective_node_metadata in node.collective_idxs: + task_idx, op_type = collective_node_metadata + collective_node = graph[task_idx][op_type] + collective_node.ready_collective_idxs.add( + (node.task_idx, node.operation.type) + ) + if node.is_ready: + heapq.heappush( + actor_to_candidates[node.actor_handle._actor_id], + node, + ) - #2 If two candidate nodes belong to different actors and both are either NCCL - write nodes or neither are NCCL write nodes, select the one with the smaller - `exec_task_idx`. If they have the same `exec_task_idx`, select the one with the - smaller `task_idx`. - #3 If two candidate nodes belong to different actors and one of them is an NCCL - write node, select the one that is not an NCCL write node. +def _select_next_nodes( + actor_to_candidates: Dict["ray._raylet.ActorID", List[_DAGOperationGraphNode]], + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], +) -> Optional[List[_DAGOperationGraphNode]]: + """ + This function selects the next nodes for the topological sort to generate + execution schedule. If there are multiple candidate _DAGOperationGraphNodes, + select the node with the top priority. The priority is defined in + `_DAGOperationGraphNode.__lt__`. For the implementation details, we maintain a priority queue for each actor, where the head of the priority queue is the node with the smallest `exec_task_idx`. + When a node has a zero in-degree, it is added to the corresponding actor's + priority queue. For a node other than a NCCL collective node, it is ready to be + executed if it has a zero in-degree. For a NCCL collective node, it is ready to + be executed when all the nodes in its collective operation have zero in-degrees. - If the selected node is an NCCL write node, select all its immediately downstream - nodes, which are NCCL read nodes, regardless of whether the downstream nodes are - heads of their own priority queues. In that case, this function only removes the - NCCL write node, which is also the head of a priority queue. Other nodes will be - removed in the following iterations. The NCCL read nodes will be returned even - though they should not yet be in the candidate list. + If a node is a NCCL collective node, it updates the `ready_collective_nodes` of + all the nodes in its collective operation. Unless all the nodes in its collective + group have zero in-degrees, this node is removed from the candidate list. + Eventually, exactly one NCCL collective node from its collective operation is + selected from the candidate list. + + If the selected node is a NCCL write node, select all the downstream NCCL + read nodes. If the selected node is a NCCL collective node, select all the NCCL + compute nodes in its collective operation. Args: actor_to_candidates: A dictionary mapping an actor id to a list of @@ -175,32 +250,42 @@ def _select_next_nodes( execution schedules. 
""" top_priority_node = None - next_nodes: List[_DAGOperationGraphNode] = [] for _, candidates in actor_to_candidates.items(): if len(candidates) == 0: continue if top_priority_node is None or candidates[0] < top_priority_node: top_priority_node = candidates[0] - assert top_priority_node is not None - next_nodes.append( + + if top_priority_node is None: + return None + next_nodes = [ heapq.heappop(actor_to_candidates[top_priority_node.actor_handle._actor_id]) - ) + ] - if not ( - top_priority_node.operation.type == _DAGNodeOperationType.WRITE - and top_priority_node.requires_nccl - ): + if not top_priority_node.is_nccl_op: + # A non-NCCL operation node is picked. assert len(next_nodes) == 1 - return next_nodes - - # An NCCL write node is picked. NCCL is a blocking operation, so we need to pick all - # the corresponding NCCL read nodes to avoid a deadlock. - for downstream_node_metadata in top_priority_node.out_edges: - task_idx, op_type = downstream_node_metadata[0], downstream_node_metadata[1] - downstream_node = graph[task_idx][op_type] - assert downstream_node.operation.type == _DAGNodeOperationType.READ - next_nodes.append(downstream_node) - assert len(next_nodes) == 1 + len(top_priority_node.out_edges) + elif top_priority_node.is_nccl_write: + # a NCCL write node is picked. NCCL is a blocking operation, so we need + # to pick all the corresponding NCCL read nodes to avoid a deadlock. + for downstream_node_metadata in top_priority_node.out_edges: + task_idx, op_type = downstream_node_metadata + downstream_node = graph[task_idx][op_type] + assert downstream_node.is_read + next_nodes.append(downstream_node) + assert len(next_nodes) == 1 + len(top_priority_node.out_edges) + elif top_priority_node.is_nccl_collective: + # a NCCL collective node is picked. NCCL is a blocking operation, so we need + # to pick all the corresponding NCCL collective nodes in its collective + # operation to avoid a deadlock. + for collective_node_metadata in top_priority_node.collective_idxs: + task_idx, op_type = collective_node_metadata + collective_node = graph[task_idx][op_type] + assert collective_node.is_nccl_collective and collective_node.is_ready + if collective_node != top_priority_node: + next_nodes.append(collective_node) + assert len(next_nodes) == len(top_priority_node.collective_idxs) + return next_nodes @@ -268,23 +353,28 @@ def _build_dag_node_operation_graph( } # Import `ray.dag` here to avoid circular import. - from ray.dag import ClassMethodNode, MultiOutputNode + from ray.dag import ClassMethodNode, CollectiveOutputNode, MultiOutputNode # Add an edge from WRITE of the writer task to READ of the reader task. for task_idx, task in idx_to_task.items(): - if ( + if not ( isinstance(task.dag_node, ClassMethodNode) - and task.dag_node.is_class_method_output + or isinstance(task.dag_node, CollectiveOutputNode) ): - # TODO(wxdeng): Handle the case where the task is a class method output. - continue - if not isinstance(task.dag_node, ClassMethodNode): # The graph is used to generate an execution schedule for each actor. # The edge from the InputNode has no impact on the final execution # schedule. continue + if ( + isinstance(task.dag_node, ClassMethodNode) + and task.dag_node.is_class_method_output + ): + # TODO(wxdeng): Handle the case where the task is a class method output. 
+ continue for downstream_task_idx in task.downstream_task_idxs: downstream_dag_node = idx_to_task[downstream_task_idx].dag_node + if isinstance(downstream_dag_node, MultiOutputNode): + continue if ( isinstance(downstream_dag_node, ClassMethodNode) and downstream_dag_node.is_class_method_output @@ -292,12 +382,11 @@ def _build_dag_node_operation_graph( # TODO(wxdeng): Handle the case where the downstream task is # a class method output. continue - if isinstance(downstream_dag_node, MultiOutputNode): - continue _add_edge( graph[task_idx][_DAGNodeOperationType.WRITE], graph[downstream_task_idx][_DAGNodeOperationType.READ], ) + return graph @@ -339,32 +428,44 @@ def _generate_actor_to_execution_schedule( # have been satisfied, including both data and control dependencies. # Therefore, it is a candidate for execution. if node.in_degree == 0: - heapq.heappush(actor_to_candidates[node.actor_handle._actor_id], node) + _push_candidate_node_if_ready(actor_to_candidates, graph, node) visited_nodes = set() - # Use topological sort algorithm to generate the execution schedule. Each iteration - # pops a candidate node from `actor_to_candidates` and each DAG node consists of - # three operations: READ, COMPUTE, and WRITE. - for _ in range(len(graph) * 3): - # The function `_select_next_nodes` will pop a candidate node from - # `actor_to_candidates` and return a list of nodes that can be executed - # in the next step. If multiple nodes are returned, only the NCCL write - # node is popped in this iteration. + # Use topological sort algorithm to generate the execution schedule. + while True: + # Select a list of nodes to be executed. There are three cases: + # 1. If a selected node is not a NCCL operation, only itself is returned. + # 2. If a selected node is a NCCL write operation, the corresponding NCCL + # read operations are also returned. + # 3. If a selected node is a NCCL collective operation, all the nodes in + # its collective operation are returned. + # In cases 1 and 3, all the selected nodes are ready. In case 2, the NCCL + # write node is ready, while the NCCL read nodes are not ready until their + # in-degrees are updated. nodes = _select_next_nodes(actor_to_candidates, graph) + if nodes is None: + break + # Filter out the visited nodes. + nodes = [node for node in nodes if node not in visited_nodes] + # Add the selected nodes to the execution schedule. for node in nodes: - if node in visited_nodes: - continue actor_to_execution_schedule[node.actor_handle].append(node.operation) visited_nodes.add(node) + # Update the in-degree of the downstream nodes. + for node in nodes: for out_node_task_idx, out_node_type in node.out_edges: out_node = graph[out_node_task_idx][out_node_type] out_node.in_edges.remove((node.task_idx, node.operation.type)) - if out_node.in_degree == 0: - heapq.heappush( - actor_to_candidates[out_node.actor_handle._actor_id], - out_node, - ) + if out_node.in_degree == 0 and out_node not in visited_nodes: + # If the downstream node is already visited, it has been added + # to the execution schedule. They are the NCCL read nodes in + # case 2. 
+ _push_candidate_node_if_ready(actor_to_candidates, graph, out_node) + assert len(visited_nodes) == len(graph) * 3, "Expected all nodes to be visited" + for node in visited_nodes: + assert node.is_ready, f"Expected {node} to be ready" for _, candidates in actor_to_candidates.items(): - assert len(candidates) == 0 + assert len(candidates) == 0, "Expected all candidates to be empty" + return actor_to_execution_schedule diff --git a/python/ray/dag/tests/experimental/test_collective_dag.py b/python/ray/dag/tests/experimental/test_collective_dag.py new file mode 100644 index 000000000000..680e6fd27dfb --- /dev/null +++ b/python/ray/dag/tests/experimental/test_collective_dag.py @@ -0,0 +1,530 @@ +# coding: utf-8 +import logging +import os +import sys +import uuid +import copy +from typing import Dict, FrozenSet, List, Optional, Set, Tuple + +import pytest +import ray +import ray.cluster_utils +import ray.experimental.collective as collective +import torch +from ray.dag import InputNode, MultiOutputNode +from ray.experimental.channel.torch_tensor_type import TorchTensorType +from ray.experimental.channel.common import ChannelContext +from ray.experimental.channel.gpu_communicator import ( + GPUCommunicator, + TorchTensorAllocator, +) +from ray.tests.conftest import * # noqa +from ray.util.collective.types import ReduceOp + +logger = logging.getLogger(__name__) + +if sys.platform != "linux" and sys.platform != "darwin": + pytest.skip("Skipping, requires Linux or Mac.", allow_module_level=True) + + +@ray.remote +class CPUTorchTensorWorker: + def __init__(self): + self.device = "cpu" + + def return_tensor(self, size: int) -> torch.Tensor: + return torch.ones(size, device=self.device) + + def recv(self, tensor: torch.Tensor) -> Tuple[int, int]: + assert tensor.device == self.device + return tensor.shape, tensor[0] + + +def mock_do_init_nccl_group( + self, + group_id: str, + rank: int, + actors: List[ray.actor.ActorHandle], + custom_nccl_group: Optional[GPUCommunicator], +) -> None: + ctx = ChannelContext.get_current() + if custom_nccl_group is None: + nccl_group = AbstractNcclGroup(actors) + nccl_group.initialize(rank) + ctx.nccl_groups[group_id] = nccl_group + else: + custom_nccl_group.initialize(rank) + ctx.nccl_groups[group_id] = custom_nccl_group + + +def mock_do_destroy_nccl_group(self, group_id: str) -> None: + ctx = ChannelContext.get_current() + if group_id not in ctx.nccl_groups: + return + ctx.nccl_groups[group_id].destroy() + del ctx.nccl_groups[group_id] + + +class AbstractNcclGroup(GPUCommunicator): + """ + A dummy NCCL group for testing. 
+ """ + + def __init__(self, actor_handles: List[ray.actor.ActorHandle]): + self._actor_handles = actor_handles + self._rank = None + + def initialize(self, rank: int) -> None: + self._rank = rank + + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + return self._actor_handles.index(actor) + + def get_world_size(self) -> int: + return len(self._actor_handles) + + def get_self_rank(self) -> Optional[int]: + return self._rank + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + def send(self, value: "torch.Tensor", peer_rank: int) -> None: + raise NotImplementedError + + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + raise NotImplementedError + + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp = ReduceOp.SUM, + ) -> None: + raise NotImplementedError + + def destroy(self) -> None: + pass + + +class MockNcclGroupSet: + def __init__(self): + # Represents a mapping from a NCCL group ID to a set of actors and a custom + # NCCL group. + self.ids_to_actors_and_custom_comms: Dict[ + str, Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] + ] = {} + + def __call__( + self, + actors: List["ray.actor.ActorHandle"], + custom_nccl_group: Optional[GPUCommunicator] = None, + ) -> str: + group_id = str(uuid.uuid4()) + self.ids_to_actors_and_custom_comms[group_id] = ( + frozenset(actors), + custom_nccl_group, + ) + + if custom_nccl_group is None: + ranks = list(range(len(actors))) + else: + ranks = [custom_nccl_group.get_rank(actor) for actor in actors] + init_tasks = [ + actor.__ray_call__.remote( + mock_do_init_nccl_group, + group_id, + rank, + actors, + custom_nccl_group, + ) + for rank, actor in zip(ranks, actors) + ] + ray.get(init_tasks, timeout=30) + + ctx = ChannelContext.get_current() + if custom_nccl_group is not None: + ctx.nccl_groups[group_id] = custom_nccl_group + else: + ctx.nccl_groups[group_id] = AbstractNcclGroup(actors) + + return group_id + + def mock_destroy_nccl_group(self, group_id: str) -> None: + ctx = ChannelContext.get_current() + if group_id not in ctx.nccl_groups: + return + + actors, _ = self.ids_to_actors_and_custom_comms[group_id] + destroy_tasks = [ + actor.__ray_call__.remote( + mock_do_destroy_nccl_group, + group_id, + ) + for actor in actors + ] + ray.wait(destroy_tasks, timeout=30) + + if group_id in self.ids_to_actors_and_custom_comms: + del self.ids_to_actors_and_custom_comms[group_id] + ctx.nccl_groups[group_id].destroy() + del ctx.nccl_groups[group_id] + + def check_init( + self, + compiled_dag: "ray.dag.CompiledDAG", + actors_and_custom_comms: Set[ + Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] + ], + p2p_actors_and_custom_comm: Optional[ + Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] + ], + ) -> None: + assert len(self.ids_to_actors_and_custom_comms) == len(actors_and_custom_comms) + assert ( + set(self.ids_to_actors_and_custom_comms.values()) == actors_and_custom_comms + ) + + nccl_group_id_p2p = compiled_dag.nccl_group_id_p2p + if p2p_actors_and_custom_comm is None: + assert nccl_group_id_p2p is None + else: + assert nccl_group_id_p2p + assert ( + self.ids_to_actors_and_custom_comms[nccl_group_id_p2p] + == p2p_actors_and_custom_comm + ) + + def check_teardown(self, nccl_group_ids: List[str]) -> None: + ctx = ChannelContext.get_current() + for nccl_group_id in nccl_group_ids: + assert 
nccl_group_id not in self.ids_to_actors_and_custom_comms + assert nccl_group_id not in ctx.nccl_groups + + +def check_nccl_group_init( + monkeypatch, + dag: "ray.dag.DAGNode", + actors_and_custom_comms: Set[ + Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] + ], + p2p_actors_and_custom_comm: Optional[ + Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[GPUCommunicator]] + ] = None, +) -> "ray.dag.CompiledDAG": + mock_nccl_group_set = MockNcclGroupSet() + monkeypatch.setattr( + "ray.dag.compiled_dag_node._init_nccl_group", + mock_nccl_group_set, + ) + monkeypatch.setattr( + "ray.dag.collective_node._init_nccl_group", + mock_nccl_group_set, + ) + + compiled_dag = dag.experimental_compile() + mock_nccl_group_set.check_init( + compiled_dag, + actors_and_custom_comms, + p2p_actors_and_custom_comm, + ) + + return compiled_dag, mock_nccl_group_set + + +def check_nccl_group_teardown( + monkeypatch, + compiled_dag: "ray.dag.CompiledDAG", + mock_nccl_group_set: MockNcclGroupSet, +): + monkeypatch.setattr( + "ray.dag.compiled_dag_node._destroy_nccl_group", + mock_nccl_group_set.mock_destroy_nccl_group, + ) + + nccl_group_ids = copy.deepcopy(compiled_dag.nccl_group_ids) + compiled_dag.teardown() + mock_nccl_group_set.check_teardown(nccl_group_ids) + + +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_all_reduce_duplicate_actors(ray_start_regular): + """ + Test an error is thrown when two input nodes from the same actor bind to + an all-reduce. + """ + actor_cls = CPUTorchTensorWorker.options() + worker = actor_cls.remote() + + with InputNode() as inp: + computes = [worker.return_tensor.bind(inp) for _ in range(2)] + with pytest.raises( + ValueError, + match="Expected unique actor handles for a collective operation", + ): + collective.allreduce.bind(computes) + + with InputNode() as inp: + compute = worker.return_tensor.bind(inp) + computes = [compute for _ in range(2)] + with pytest.raises( + ValueError, + match="Expected unique input nodes for a collective operation", + ): + collective.allreduce.bind(computes) + + +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_all_reduce_custom_comm_wrong_actors(ray_start_regular): + """ + Test an error is thrown when an all-reduce binds to a custom NCCL group and + a wrong set of actors. + """ + actor_cls = CPUTorchTensorWorker.options() + + num_workers = 2 + workers = [actor_cls.remote() for _ in range(num_workers)] + + nccl_group = AbstractNcclGroup([workers[0]]) + with InputNode() as inp: + computes = [worker.return_tensor.bind(inp) for worker in workers] + with pytest.raises( + ValueError, + match="Expected actor handles to match the custom NCCL group", + ): + collective.allreduce.bind(computes, transport=nccl_group) + + +@pytest.mark.parametrize( + "ray_start_regular", [{"num_cpus": 4, "num_gpus": 4}], indirect=True +) +def test_comm_all_reduces(ray_start_regular, monkeypatch): + """ + Test different communicators are used for different all-reduce calls of + different sets of actors. + """ + actor_cls = CPUTorchTensorWorker.options(num_cpus=0, num_gpus=1) + + num_workers = 2 + workers = [actor_cls.remote() for _ in range(num_workers)] + + with InputNode() as inp: + computes = [worker.return_tensor.bind(inp) for worker in workers] + # There are two all-reduces, each on one actor. + collectives = [collective.allreduce.bind([compute]) for compute in computes] + # collective[0] is the only CollectiveOutputNode for each all-reduce. 
+ dag = MultiOutputNode([collective[0] for collective in collectives]) + + compiled_dag, mock_nccl_group_set = check_nccl_group_init( + monkeypatch, + dag, + { + (frozenset([workers[0]]), None), + (frozenset([workers[1]]), None), + }, + ) + + check_nccl_group_teardown(monkeypatch, compiled_dag, mock_nccl_group_set) + + +@pytest.mark.parametrize( + "ray_start_regular", [{"num_cpus": 4, "num_gpus": 4}], indirect=True +) +def test_comm_deduplicate_all_reduces(ray_start_regular, monkeypatch): + """ + Test communicators are deduplicated when all-reduces are called on the same + group of actors more than once. + """ + actor_cls = CPUTorchTensorWorker.options(num_cpus=0, num_gpus=1) + + num_workers = 2 + workers = [actor_cls.remote() for _ in range(num_workers)] + + with InputNode() as inp: + tensors = [worker.return_tensor.bind(inp) for worker in workers] + collectives = collective.allreduce.bind(tensors) + collectives = collective.allreduce.bind(collectives) + dag = MultiOutputNode(collectives) + + compiled_dag, mock_nccl_group_set = check_nccl_group_init( + monkeypatch, + dag, + {(frozenset(workers), None)}, + ) + + check_nccl_group_teardown(monkeypatch, compiled_dag, mock_nccl_group_set) + + +@pytest.mark.parametrize( + "ray_start_regular", [{"num_cpus": 4, "num_gpus": 4}], indirect=True +) +def test_comm_deduplicate_p2p_and_collective(ray_start_regular, monkeypatch): + """ + Test communicators are deduplicated when the collective and the P2P are on + the same set of actors. + """ + actor_cls = CPUTorchTensorWorker.options(num_cpus=0, num_gpus=1) + + num_workers = 2 + workers = [actor_cls.remote() for _ in range(num_workers)] + + with InputNode() as inp: + computes = [worker.return_tensor.bind(inp) for worker in workers] + collectives = collective.allreduce.bind(computes) + recvs = [ + # Each of the 2 workers receives from the other. + workers[0].recv.bind( + collectives[1].with_type_hint(TorchTensorType(transport="nccl")) + ), + workers[1].recv.bind( + collectives[0].with_type_hint(TorchTensorType(transport="nccl")) + ), + ] + dag = MultiOutputNode(recvs) + + compiled_dag, mock_nccl_group_set = check_nccl_group_init( + monkeypatch, + dag, + {(frozenset(workers), None)}, + (frozenset(workers), None), + ) + + check_nccl_group_teardown(monkeypatch, compiled_dag, mock_nccl_group_set) + + with InputNode() as inp: + computes = [worker.return_tensor.bind(inp) for worker in workers] + collectives = collective.allreduce.bind(computes) + # Sender is workers[0] and receiver is workers[1]. + dag = workers[1].recv.bind( + collectives[0].with_type_hint(TorchTensorType(transport="nccl")) + ) + + compiled_dag, mock_nccl_group_set = check_nccl_group_init( + monkeypatch, + dag, + {(frozenset(workers), None)}, + (frozenset(workers), None), + ) + + check_nccl_group_teardown(monkeypatch, compiled_dag, mock_nccl_group_set) + + +@pytest.mark.parametrize( + "ray_start_regular", [{"num_cpus": 4, "num_gpus": 4}], indirect=True +) +def test_custom_comm_deduplicate(ray_start_regular, monkeypatch): + """ + Test a custom GPU communicator is reused when possible. 
+ """ + actor_cls = CPUTorchTensorWorker.options(num_cpus=0, num_gpus=1) + + num_workers = 2 + workers = [actor_cls.remote() for _ in range(num_workers)] + + comm = AbstractNcclGroup(workers) + with InputNode() as inp: + computes = [worker.return_tensor.bind(inp) for worker in workers] + collectives = collective.allreduce.bind(computes, transport=comm) + collectives = collective.allreduce.bind(collectives) + dag = workers[0].recv.bind( + collectives[1].with_type_hint(TorchTensorType(transport="nccl")) + ) + + compiled_dag, mock_nccl_group_set = check_nccl_group_init( + monkeypatch, + dag, + {(frozenset(workers), comm)}, + (frozenset(workers), comm), + ) + + check_nccl_group_teardown(monkeypatch, compiled_dag, mock_nccl_group_set) + + comm = AbstractNcclGroup(workers) + with InputNode() as inp: + computes = [worker.return_tensor.bind(inp) for worker in workers] + collectives = collective.allreduce.bind(computes) + collectives = collective.allreduce.bind(collectives) + dag = workers[0].recv.bind( + collectives[1].with_type_hint(TorchTensorType(transport=comm)) + ) + + compiled_dag, mock_nccl_group_set = check_nccl_group_init( + monkeypatch, + dag, + {(frozenset(workers), comm)}, + (frozenset(workers), comm), + ) + + check_nccl_group_teardown(monkeypatch, compiled_dag, mock_nccl_group_set) + + +@pytest.mark.parametrize( + "ray_start_regular", [{"num_cpus": 4, "num_gpus": 4}], indirect=True +) +def test_custom_comm_init_teardown(ray_start_regular, monkeypatch): + """ + Test custom NCCL groups are properly initialized and destroyed. + 1. Test when multiple type hints have the same `transport=custom_nccl_group`, + the `custom_nccl_group` is initialized only once. + 2. Test all initialized NCCL groups are destroyed during teardown. + """ + actor_cls = CPUTorchTensorWorker.options(num_cpus=0, num_gpus=1) + + num_workers = 2 + workers = [actor_cls.remote() for _ in range(num_workers)] + + comm = AbstractNcclGroup(workers) + + with InputNode() as inp: + tensors = [worker.return_tensor.bind(inp) for worker in workers] + allreduce = collective.allreduce.bind(tensors, transport=comm) + dag = workers[0].recv.bind( + allreduce[1].with_type_hint(TorchTensorType(transport=comm)) + ) + + compiled_dag, mock_nccl_group_set = check_nccl_group_init( + monkeypatch, + dag, + {(frozenset(workers), comm)}, + (frozenset(workers), comm), + ) + + check_nccl_group_teardown(monkeypatch, compiled_dag, mock_nccl_group_set) + + comm_1 = AbstractNcclGroup(workers) + comm_2 = AbstractNcclGroup(workers) + comm_3 = AbstractNcclGroup(workers) + + with InputNode() as inp: + tensors = [worker.return_tensor.bind(inp) for worker in workers] + allreduce1 = collective.allreduce.bind(tensors, transport=comm_1) + allreduce2 = collective.allreduce.bind(allreduce1, transport=comm_2) + dag = workers[0].recv.bind( + allreduce2[1].with_type_hint(TorchTensorType(transport=comm_3)) + ) + + compiled_dag, mock_nccl_group_set = check_nccl_group_init( + monkeypatch, + dag, + { + (frozenset(workers), comm_1), + (frozenset(workers), comm_2), + (frozenset(workers), comm_3), + }, + (frozenset(workers), comm_3), + ) + + check_nccl_group_teardown(monkeypatch, compiled_dag, mock_nccl_group_set) + + +if __name__ == "__main__": + if os.environ.get("PARALLEL_CI"): + sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) + else: + sys.exit(pytest.main(["-sv", __file__])) diff --git a/python/ray/dag/tests/experimental/test_execution_schedule.py b/python/ray/dag/tests/experimental/test_execution_schedule.py index 8633107a2c25..34c8777f0be4 100644 
--- a/python/ray/dag/tests/experimental/test_execution_schedule.py +++ b/python/ray/dag/tests/experimental/test_execution_schedule.py @@ -35,18 +35,45 @@ def mock_init(self): pass -def generate_dag_graph_nodes(exec_task_idx, task_idx, actor_handle, requires_nccl): +def generate_dag_graph_nodes( + exec_task_idx, task_idx, actor_handle, requires_nccl, requires_collective=False +): graph_nodes = {} for op_type in _DAGNodeOperationType: + op_requires_nccl = ( + op_type == _DAGNodeOperationType.WRITE and requires_nccl + ) or (op_type == _DAGNodeOperationType.COMPUTE and requires_collective) graph_nodes[op_type] = _DAGOperationGraphNode( _DAGNodeOperation(exec_task_idx, op_type), task_idx, actor_handle, - requires_nccl, + op_requires_nccl, ) return graph_nodes +def set_collective_idxs( + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], + dag_idxs: List[int], +) -> None: + collective_idxs = {(dag_idx, _DAGNodeOperationType.COMPUTE) for dag_idx in dag_idxs} + for dag_idx in dag_idxs: + graph[dag_idx][_DAGNodeOperationType.COMPUTE].collective_idxs = collective_idxs + + +def set_ready_collective_idxs( + graph: Dict[int, Dict[_DAGNodeOperationType, _DAGOperationGraphNode]], + dag_idxs: List[int], +) -> None: + ready_collective_idxs = { + (dag_idx, _DAGNodeOperationType.COMPUTE) for dag_idx in dag_idxs + } + for dag_idx in dag_idxs: + graph[dag_idx][ + _DAGNodeOperationType.COMPUTE + ].ready_collective_idxs = ready_collective_idxs + + class TestSelectNextNodes: """ Test whether `_select_next_nodes` function selects the next nodes for @@ -228,6 +255,94 @@ def test_two_nccl_writes(self, monkeypatch): ) assert next_nodes[1] == mock_graph[task_idx_2_1][_DAGNodeOperationType.READ] + def test_only_one_nccl_collective(self, monkeypatch): + """ + Simulate the case where there is only one candidate which is a NCCL + collective operation. In this case, `_select_next_nodes` should return + all the NCCL collective nodes. + + driver -> fake_actor_1.allreduce_1 -> driver + | | + -> fake_actor_2.allreduce_1 -> + """ + monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) + fake_actor_1, dag_idx_1, local_idx_1 = ActorHandle("fake_actor_1"), 1, 0 + fake_actor_2, dag_idx_2, local_idx_2 = ActorHandle("fake_actor_2"), 2, 0 + + mock_graph = { + dag_idx_1: generate_dag_graph_nodes( + local_idx_1, dag_idx_1, fake_actor_1, True, True + ), + dag_idx_2: generate_dag_graph_nodes( + local_idx_2, dag_idx_2, fake_actor_2, True, True + ), + } + set_collective_idxs(mock_graph, [dag_idx_1, dag_idx_2]) + set_ready_collective_idxs(mock_graph, [dag_idx_1, dag_idx_2]) + + mock_actor_to_candidates = { + fake_actor_1: [mock_graph[dag_idx_1][_DAGNodeOperationType.COMPUTE]], + } + next_nodes = _select_next_nodes(mock_actor_to_candidates, mock_graph) + assert set(next_nodes) == { + mock_graph[dag_idx_1][_DAGNodeOperationType.COMPUTE], + mock_graph[dag_idx_2][_DAGNodeOperationType.COMPUTE], + } + + def test_two_nccl_collectives(self, monkeypatch): + """ + Simulate the case where there are two candidates that are NCCL collective + operations. In this case, `_select_next_nodes` should return all the NCCL + collective nodes that are bond earlier. 
+ + driver -> fake_actor_1.allreduce_1 -> driver + | | + -> fake_actor_2.allreduce_1 -> + | | + -> fake_actor_3.allreduce_2 -> + | | + -> fake_actor_4.allreduce_2 -> + """ + monkeypatch.setattr(ActorHandle, "__init__", mock_actor_handle_init) + fake_actor_1, dag_idx_1, local_idx_1 = ActorHandle("fake_actor_1"), 1, 0 + fake_actor_2, dag_idx_2, local_idx_2 = ActorHandle("fake_actor_2"), 2, 0 + fake_actor_3, dag_idx_3, local_idx_3 = ActorHandle("fake_actor_3"), 3, 0 + fake_actor_4, dag_idx_4, local_idx_4 = ActorHandle("fake_actor_4"), 4, 0 + + mock_graph = { + dag_idx_1: generate_dag_graph_nodes( + local_idx_1, dag_idx_1, fake_actor_1, True, True + ), + dag_idx_2: generate_dag_graph_nodes( + local_idx_2, dag_idx_2, fake_actor_2, True, True + ), + dag_idx_3: generate_dag_graph_nodes( + local_idx_3, dag_idx_3, fake_actor_3, True, True + ), + dag_idx_4: generate_dag_graph_nodes( + local_idx_4, dag_idx_4, fake_actor_4, True, True + ), + } + set_collective_idxs(mock_graph, [dag_idx_1, dag_idx_2]) + set_ready_collective_idxs(mock_graph, [dag_idx_1, dag_idx_2]) + set_collective_idxs(mock_graph, [dag_idx_3, dag_idx_4]) + set_ready_collective_idxs(mock_graph, [dag_idx_3, dag_idx_4]) + + mock_actor_to_candidates = { + fake_actor_2: [mock_graph[dag_idx_2][_DAGNodeOperationType.COMPUTE]], + fake_actor_4: [mock_graph[dag_idx_4][_DAGNodeOperationType.COMPUTE]], + } + next_nodes = _select_next_nodes(mock_actor_to_candidates, mock_graph) + assert set(next_nodes) == { + mock_graph[dag_idx_1][_DAGNodeOperationType.COMPUTE], + mock_graph[dag_idx_2][_DAGNodeOperationType.COMPUTE], + } + next_nodes = _select_next_nodes(mock_actor_to_candidates, mock_graph) + assert set(next_nodes) == { + mock_graph[dag_idx_3][_DAGNodeOperationType.COMPUTE], + mock_graph[dag_idx_4][_DAGNodeOperationType.COMPUTE], + } + class TestBuildDAGNodeOperationGraph: """ diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py index bde6d2d23735..6f6f98db628a 100644 --- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py +++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py @@ -1,30 +1,27 @@ # coding: utf-8 import logging import os +import socket import sys +import time from typing import List, Optional, Tuple + +import pytest +import ray +import ray.cluster_utils +import ray.experimental.collective as collective +import torch +from ray.air._internal import torch_utils +from ray.dag import InputNode, MultiOutputNode +from ray.exceptions import RayChannelError from ray.experimental.channel.gpu_communicator import ( GPUCommunicator, TorchTensorAllocator, ) from ray.experimental.channel.nccl_group import _NcclGroup -import socket -import torch -import time - -import pytest - -from ray.exceptions import RayChannelError -import ray -from ray.air._internal import torch_utils -import ray.cluster_utils -from ray.dag import InputNode +from ray.experimental.channel.torch_tensor_type import TorchTensorType from ray.tests.conftest import * # noqa - -from ray.experimental.channel.torch_tensor_type import ( - TorchTensorType, -) - +from ray.experimental.util.types import ReduceOp logger = logging.getLogger(__name__) @@ -77,6 +74,15 @@ def recv_dict(self, tensor_dict): vals[i] = self.recv(tensor) return vals + def compute_with_tuple_args(self, args, i: int): + shape, dtype, value = args[i] + tensor = torch.ones(shape, dtype=dtype, device=self.device) * value + return tensor + + def recv_tensor(self, tensor): + assert tensor.device == self.device + return tensor 
+ def ping(self): return @@ -97,6 +103,74 @@ def forward(self, inp): return torch.randn(10, 10) +class TestNcclGroup(GPUCommunicator): + """ + A custom NCCL group for testing. This is a simple wrapper around `_NcclGroup`. + """ + + def __init__(self, world_size, comm_id, actor_handles): + self._world_size = world_size + self._comm_id = comm_id + self._actor_handles = actor_handles + self._inner = None + + def initialize(self, rank: int) -> None: + self._inner = _NcclGroup( + self._world_size, + self._comm_id, + rank, + self._actor_handles, + torch.cuda.current_stream().cuda_stream, + ) + + def get_rank(self, actor: ray.actor.ActorHandle) -> int: + # Implement this without forwarding to `_inner` to allow the method + # to be called before initialization. + actor_ids = [a._ray_actor_id for a in self._actor_handles] + try: + rank = actor_ids.index(actor._ray_actor_id) + except ValueError: + raise ValueError("Actor is not in the NCCL group.") + return rank + + def get_world_size(self) -> int: + # Implement this without forwarding to `_inner` to allow the method + # to be called before initialization. + return self._world_size + + def get_self_rank(self) -> Optional[int]: + if self._inner is None: + return None + return self._inner.get_self_rank() + + def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: + return self._actor_handles + + def send(self, value: "torch.Tensor", peer_rank: int) -> None: + return self._inner.send(value, peer_rank) + + def recv( + self, + shape: Tuple[int], + dtype: "torch.dtype", + peer_rank: int, + allocator: Optional[TorchTensorAllocator] = None, + ) -> "torch.Tensor": + return self._inner.recv(shape, dtype, peer_rank, allocator=allocator) + + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp = ReduceOp.SUM, + ) -> None: + self._inner.allreduce(send_buf, recv_buf, op) + recv_buf += 1 + + def destroy(self) -> None: + return self._inner.destroy() + + @pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) def test_torch_tensor_p2p(ray_start_regular): if USE_GPU: @@ -316,64 +390,6 @@ def test_torch_tensor_custom_comm(ray_start_regular): sender = actor_cls.remote() receiver = actor_cls.remote() - class TestNcclGroup(GPUCommunicator): - """ - A custom NCCL group for testing. This is a simple wrapper around `_NcclGroup`. - """ - - def __init__(self, world_size, comm_id, actor_handles): - self._world_size = world_size - self._comm_id = comm_id - self._actor_handles = actor_handles - self._inner = None - - def initialize(self, rank: int) -> None: - self._inner = _NcclGroup( - self._world_size, - self._comm_id, - rank, - self._actor_handles, - torch.cuda.current_stream().cuda_stream, - ) - - def get_rank(self, actor: ray.actor.ActorHandle) -> int: - # Implement this without forwarding to `_inner` to allow the method - # to be called before initialization. - actor_ids = [a._ray_actor_id for a in self._actor_handles] - try: - rank = actor_ids.index(actor._ray_actor_id) - except ValueError: - raise ValueError("Actor is not in the NCCL group.") - return rank - - def get_world_size(self) -> int: - # Implement this without forwarding to `_inner` to allow the method - # to be called before initialization. 
- return self._world_size - - def get_self_rank(self) -> Optional[int]: - if self._inner is None: - return None - return self._inner.get_self_rank() - - def get_actor_handles(self) -> List["ray.actor.ActorHandle"]: - return self._actor_handles - - def send(self, value: "torch.Tensor", peer_rank: int) -> None: - return self._inner.send(value, peer_rank) - - def recv( - self, - shape: Tuple[int], - dtype: "torch.dtype", - peer_rank: int, - allocator: Optional[TorchTensorAllocator] = None, - ) -> "torch.Tensor": - return self._inner.recv(shape, dtype, peer_rank, allocator=allocator) - - def destroy(self) -> None: - return self._inner.destroy() - from cupy.cuda import nccl comm_id = nccl.get_unique_id() @@ -457,6 +473,14 @@ def recv( ) -> "torch.Tensor": return None + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp, + ) -> None: + raise NotImplementedError + def destroy(self) -> None: pass @@ -590,10 +614,19 @@ def recv( torch.distributed.recv(tensor, peer_rank) return tensor + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp, + ) -> None: + raise NotImplementedError + def destroy(self) -> None: pass nccl_group = InitedNcclGroup(2, [sender, receiver]) + with InputNode() as inp: dag = sender.send_with_tuple_args.bind(inp) dag = dag.with_type_hint(TorchTensorType(transport=nccl_group)) @@ -863,6 +896,259 @@ def test_torch_tensor_exceptions(ray_start_regular): compiled_dag.teardown() +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_torch_tensor_nccl_all_reduce(ray_start_regular): + """ + Test basic all-reduce. + """ + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + assert ( + sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1 + ), "This test requires at least 2 GPUs" + + actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) + + num_workers = 2 + workers = [actor_cls.remote() for _ in range(num_workers)] + + with InputNode() as inp: + computes = [ + worker.compute_with_tuple_args.bind(inp, i) + for i, worker in enumerate(workers) + ] + collectives = collective.allreduce.bind(computes, ReduceOp.SUM) + recvs = [ + worker.recv.bind(collective) + for worker, collective in zip(workers, collectives) + ] + dag = MultiOutputNode(recvs) + + compiled_dag = dag.experimental_compile() + + for i in range(3): + i += 1 + shape = (i * 10,) + dtype = torch.float16 + ref = compiled_dag.execute( + [(shape, dtype, i + idx) for idx in range(num_workers)] + ) + result = ray.get(ref) + reduced_val = sum(i + idx for idx in range(num_workers)) + assert result == [(reduced_val, shape, dtype) for _ in workers] + + compiled_dag.teardown() + + +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_torch_tensor_nccl_all_reduce_get_partial(ray_start_regular): + """ + Test getting partial results from an all-reduce does not hang. 
+ """ + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + assert ( + sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1 + ), "This test requires at least 2 GPUs" + + actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) + + num_workers = 2 + workers = [actor_cls.remote() for _ in range(num_workers)] + + shape = (10,) + dtype = torch.float16 + + with InputNode() as inp: + computes = [ + worker.compute_with_tuple_args.bind(inp, i) + for i, worker in enumerate(workers) + ] + collectives = collective.allreduce.bind(computes, ReduceOp.SUM) + recv = workers[0].recv.bind(collectives[0]) + tensor = workers[1].recv_tensor.bind(collectives[0]) + dag = MultiOutputNode([recv, tensor]) + + compiled_dag = dag.experimental_compile() + + for i in range(3): + ref = compiled_dag.execute( + [(shape, dtype, i + idx + 1) for idx in range(num_workers)] + ) + result = ray.get(ref) + metadata, tensor = result + reduced_val = sum(i + idx + 1 for idx in range(num_workers)) + assert metadata == (reduced_val, shape, dtype) + tensor = tensor.to("cpu") + expected_tensor_val = torch.ones(shape, dtype=dtype) * reduced_val + assert torch.equal(tensor, expected_tensor_val) + + compiled_dag.teardown() + + +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_torch_tensor_nccl_all_reduce_wrong_shape(ray_start_regular): + """ + Test an error is thrown when an all-reduce takes tensors of wrong shapes. + """ + if not USE_GPU: + pytest.skip("NCCL tests require GPUs") + + assert ( + sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1 + ), "This test requires at least 2 GPUs" + + actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1) + + num_workers = 2 + workers = [actor_cls.remote() for _ in range(num_workers)] + + dtype = torch.float16 + + with InputNode() as inp: + computes = [ + worker.compute_with_tuple_args.bind(inp, i) + for i, worker in enumerate(workers) + ] + collectives = collective.allreduce.bind(computes, ReduceOp.SUM) + recvs = [ + worker.recv.bind(collective) + for worker, collective in zip(workers, collectives) + ] + dag = MultiOutputNode(recvs) + + compiled_dag = dag.experimental_compile() + + ref = compiled_dag.execute([((20,), dtype, idx + 1) for idx in range(num_workers)]) + reduced_val = (1 + num_workers) * num_workers / 2 + assert ray.get(ref) == [(reduced_val, (20,), dtype) for _ in range(num_workers)] + + ref = compiled_dag.execute( + [((10 * (idx + 1),), dtype, idx + 1) for idx in range(num_workers)] + ) + # Execution hangs because of shape mismatch and a timeout error is raised. + with pytest.raises(RayChannelError): + ray.get(ref) + + # The DAG will be torn down after any task throws an application-level + # exception, such as when the task returns torch.Tensors of the wrong + # shape or dtype. Check that we can no longer submit to the DAG. + ref = compiled_dag.execute([((20,), dtype, 1) for _ in workers]) + with pytest.raises(RayChannelError): + ref = compiled_dag.execute([((20,), dtype, 1) for _ in workers]) + + compiled_dag.teardown() + + +@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True) +def test_torch_tensor_nccl_all_reduce_custom_comm(ray_start_regular): + """ + Test all-reduce works with a custom communicator. 
+    """
+    if not USE_GPU:
+        pytest.skip("NCCL tests require GPUs")
+
+    assert (
+        sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
+    ), "This test requires at least 2 GPUs"
+
+    actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)
+
+    num_workers = 2
+    workers = [actor_cls.remote() for _ in range(num_workers)]
+
+    from cupy.cuda import nccl
+
+    comm_id = nccl.get_unique_id()
+    nccl_group = TestNcclGroup(2, comm_id, workers)
+    with InputNode() as inp:
+        computes = [
+            worker.compute_with_tuple_args.bind(inp, i)
+            for i, worker in enumerate(workers)
+        ]
+        collectives = collective.allreduce.bind(computes, transport=nccl_group)
+        recvs = [
+            worker.recv.bind(collective)
+            for worker, collective in zip(workers, collectives)
+        ]
+        dag = MultiOutputNode(recvs)
+
+    compiled_dag = dag.experimental_compile()
+
+    shape = (10,)
+    dtype = torch.float16
+    for i in range(3):
+        ref = compiled_dag.execute(
+            [(shape, dtype, i + idx + 1) for idx in range(num_workers)]
+        )
+        result = ray.get(ref)
+        reduced_val = sum(i + idx + 1 for idx in range(num_workers))
+        # The custom communicator adds 1 to the tensor after the all-reduce.
+        reduced_val += 1
+        assert result == [(reduced_val, shape, dtype) for _ in workers]
+
+    compiled_dag.teardown()
+
+
+@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
+def test_torch_tensor_nccl_all_reduce_scheduling(ray_start_regular):
+    """
+    Test that scheduling avoids potential deadlocks that arise from all-reduce operations.
+
+    inp --> x(0) --> +------------+
+        |            | all-reduce |
+        --> y(1) --> +------------+
+        |
+        --> t(0) --> recv(1)
+
+    In the above graph, x, y, t are tensors, and the numbers inside parentheses
+    identify the actors. If actor 1 launches the all-reduce with tensor y while
+    actor 0 starts sending t, then actor 1 waits for actor 0 to join the all-reduce
+    while actor 0 waits for actor 1 to receive t, and the two actors deadlock.
+    """
+    if not USE_GPU:
+        pytest.skip("NCCL tests require GPUs")
+
+    assert (
+        sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
+    ), "This test requires at least 2 GPUs"
+
+    actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)
+
+    num_workers = 2
+    workers = [actor_cls.remote() for _ in range(num_workers)]
+
+    shape = (10,)
+    dtype = torch.float16
+    with InputNode() as inp:
+        # Tensors in the all-reduce.
+        x = workers[0].send.bind(shape, dtype, inp)
+        y = workers[1].send.bind(shape, dtype, inp)
+
+        # Tensor to be sent from workers[0] to workers[1].
+ t = workers[0].send.bind(shape, dtype, inp) + t.with_type_hint(TorchTensorType(transport="nccl")) + + collectives = collective.allreduce.bind([x, y]) + recv = workers[1].recv.bind(t) + dag = MultiOutputNode([collectives[0], collectives[1], recv]) + + compiled_dag = dag.experimental_compile() + + value = 10 + ref = compiled_dag.execute(value) + result = ray.get(ref) + reduced_value = value * 2 + expected_tensor_val = torch.ones(shape, dtype=dtype) * reduced_value + assert torch.equal(result[0], expected_tensor_val) + assert torch.equal(result[1], expected_tensor_val) + assert result[2] == (value, shape, dtype) + + compiled_dag.teardown() + + if __name__ == "__main__": if os.environ.get("PARALLEL_CI"): sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__])) diff --git a/python/ray/experimental/channel/common.py b/python/ray/experimental/channel/common.py index 279c1a297fed..84d5d5a6c111 100644 --- a/python/ray/experimental/channel/common.py +++ b/python/ray/experimental/channel/common.py @@ -240,6 +240,7 @@ def read(self, timeout: Optional[float] = None) -> Any: Any: The deserialized value. If the deserialized value is an Exception, it will be returned directly instead of being raised. """ + raise NotImplementedError def close(self) -> None: """ diff --git a/python/ray/experimental/channel/gpu_communicator.py b/python/ray/experimental/channel/gpu_communicator.py index e6bc2fccdb2d..26cae2ff9409 100644 --- a/python/ray/experimental/channel/gpu_communicator.py +++ b/python/ray/experimental/channel/gpu_communicator.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple import ray +from ray.experimental.util.types import ReduceOp from ray.util.annotations import DeveloperAPI if TYPE_CHECKING: @@ -104,6 +105,24 @@ def recv( """ raise NotImplementedError + @abstractmethod + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp, + ) -> None: + """ + Collectively allreduce the tensor across the group. + + Args: + send_buf: The input torch.tensor to allreduce. It should already be + on this actor's default device. + recv_buf: The output torch.tensor to store the allreduce result. + op: The reduce operation. + """ + raise NotImplementedError + @abstractmethod def destroy() -> None: """ diff --git a/python/ray/experimental/channel/nccl_group.py b/python/ray/experimental/channel/nccl_group.py index dcdfef10f163..8f4804848323 100644 --- a/python/ray/experimental/channel/nccl_group.py +++ b/python/ray/experimental/channel/nccl_group.py @@ -8,6 +8,7 @@ GPUCommunicator, TorchTensorAllocator, ) +from ray.experimental.util.types import ReduceOp if TYPE_CHECKING: import cupy as cp @@ -138,19 +139,19 @@ def get_world_size(self) -> int: """ return self._world_size - def send(self, value: "torch.Tensor", peer_rank: int) -> None: + def send(self, buf: "torch.Tensor", peer_rank: int) -> None: """ Send a torch.Tensor to a peer. This returns when the send kernel has been queued, but the kernel may not have completed. Therefore, the caller should ensure that there are - no concurrent writes to the sent `value` until the send has finished. + no concurrent writes to the sent `buf` until the send has finished. That is, either all writes should be submitted on the current stream (self._cuda_stream) or, if on a different stream, that stream should synchronize with the current stream. Args: - value: The torch.Tensor to send. It should already be on this + buf: The torch.Tensor to send. It should already be on this actor's default device. 
peer_rank: The rank of the actor to send to. """ @@ -159,9 +160,9 @@ def send(self, value: "torch.Tensor", peer_rank: int) -> None: # TODO(swang): Handle send/recv async NCCL errors such as network # failures. self._comm.send( - self.nccl_util.get_tensor_ptr(value), - value.numel(), - self.nccl_util.get_nccl_tensor_dtype(value), + self.nccl_util.get_tensor_ptr(buf), + buf.numel(), + self.nccl_util.get_nccl_tensor_dtype(buf), peer_rank, self._cuda_stream.ptr, ) @@ -205,6 +206,33 @@ def recv( raise RayChannelError("NCCL group has been destroyed.") return buf + def allreduce( + self, + send_buf: "torch.Tensor", + recv_buf: "torch.Tensor", + op: ReduceOp = ReduceOp.SUM, + ): + if self._closed: + raise RayChannelError("NCCL group has been destroyed.") + + self._comm.allReduce( + self.nccl_util.get_tensor_ptr(send_buf), + self.nccl_util.get_tensor_ptr(recv_buf), + send_buf.numel(), + self.nccl_util.get_nccl_tensor_dtype(send_buf), + op.value, + self._cuda_stream.ptr, + ) + + # Buffer values are undefined if NCCL ops are aborted. Therefore, we + # need to synchronize here and check that the channel is still open to + # ensure that the receive buffer is valid. + # TODO(swang): Avoid CUDA synchronization. + # TODO(wxdeng): Use check_async_error. + self._cuda_stream.synchronize() + if self._closed: + raise RayChannelError("NCCL group has been destroyed.") + def destroy(self) -> None: """ Destroy the NCCL group. diff --git a/python/ray/experimental/channel/torch_tensor_type.py b/python/ray/experimental/channel/torch_tensor_type.py index e677488f71fa..53f231af592e 100644 --- a/python/ray/experimental/channel/torch_tensor_type.py +++ b/python/ray/experimental/channel/torch_tensor_type.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union import ray -from ray.experimental.channel import ChannelContext, ChannelOutputType +from ray.experimental.channel import ChannelContext, ChannelInterface, ChannelOutputType from ray.experimental.channel.gpu_communicator import ( GPUCommunicator, TorchTensorAllocator, @@ -126,8 +126,7 @@ def create_channel( reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]], read_by_adag_driver, _torch_tensor_allocator: Optional["TorchTensorAllocator"] = None, - ) -> type: - + ) -> ChannelInterface: if self.requires_nccl(): from ray.experimental.channel.torch_tensor_nccl_channel import ( TorchTensorNcclChannel, @@ -189,5 +188,21 @@ def set_nccl_group_id(self, group_id: str) -> None: self._nccl_group_id = group_id @property - def nccl_group_id(self) -> str: + def nccl_group_id(self) -> Optional[str]: return self._nccl_group_id + + def __deepcopy__(self, memo): + """ + Deep copy all the fields except for the custom NCCL group. The custom + NCCL group should not be deep copied because it can be shared across + `TorchTensorType` instances. 
+        """
+        copy = TorchTensorType(
+            _shape=self._shape,
+            _dtype=self._dtype,
+            transport=self.transport,
+            _direct_return=self._direct_return,
+        )
+        copy._custom_nccl_group = self._custom_nccl_group
+        copy._nccl_group_id = self._nccl_group_id
+        return copy
diff --git a/python/ray/experimental/collective/__init__.py b/python/ray/experimental/collective/__init__.py
new file mode 100644
index 000000000000..e824152b7473
--- /dev/null
+++ b/python/ray/experimental/collective/__init__.py
@@ -0,0 +1,3 @@
+from ray.experimental.collective.allreduce import allreduce
+
+__all__ = ["allreduce"]
diff --git a/python/ray/experimental/collective/allreduce.py b/python/ray/experimental/collective/allreduce.py
new file mode 100644
index 000000000000..b00f36a401c0
--- /dev/null
+++ b/python/ray/experimental/collective/allreduce.py
@@ -0,0 +1,92 @@
+import logging
+from typing import List, Optional, Union
+
+import ray
+from ray.dag.collective_node import CollectiveOutputNode, _CollectiveOperation
+from ray.dag.constants import (
+    BIND_INDEX_KEY,
+    COLLECTIVE_OPERATION_KEY,
+    PARENT_CLASS_NODE_KEY,
+)
+from ray.experimental.channel.torch_tensor_type import GPUCommunicator, TorchTensorType
+from ray.experimental.util.types import ReduceOp
+from ray.util.collective.types import ReduceOp as RayReduceOp
+
+# TODO(wxdeng): Unify `ReduceOp` and `RayReduceOp`. Directly importing `RayReduceOp`
+# has dependency issues for some tests.
+
+logger = logging.getLogger(__name__)
+
+
+class AllReduceWrapper:
+    """Wrapper for NCCL all-reduce."""
+
+    def bind(
+        self,
+        input_nodes: List["ray.dag.DAGNode"],
+        op: ReduceOp = ReduceOp.SUM,
+        transport: Optional[Union[str, GPUCommunicator]] = None,
+    ) -> List[CollectiveOutputNode]:
+        """
+        Bind input nodes with a collective operation. The collective operation is
+        applied in place to the torch tensors returned by the input nodes, and each
+        output node returns the corresponding reduced tensor.
+
+        Requirements:
+        1. Each input node returns a torch tensor.
+        2. Each input node is from a different actor.
+        3. If a custom transport is specified, its actor set matches the actor set
+           of the input nodes.
+        4. All tensors have the same shape.
+
+        Requirements 1-3 are checked in the `_CollectiveOperation` constructor.
+        Requirement 4 is not checked yet.
+
+        Args:
+            input_nodes: A list of DAG nodes.
+            op: The collective operation.
+            transport: GPU communicator for the collective operation. If not
+                specified, the default NCCL is used.
+
+        Returns:
+            A list of collective output nodes.
+ """ + if transport is None: + transport = TorchTensorType.NCCL + collective_op = _CollectiveOperation(input_nodes, op, transport) + collective_output_nodes: List[CollectiveOutputNode] = [] + + for input_node in input_nodes: + actor_handle: Optional[ + "ray.actor.ActorHandle" + ] = input_node._get_actor_handle() + if actor_handle is None: + raise ValueError("Expected an actor handle from the input node") + collective_output_node = CollectiveOutputNode( + method_name=f"allreduce.{op}", + method_args=(input_node,), + method_kwargs=dict(), + method_options=dict(), + other_args_to_resolve={ + PARENT_CLASS_NODE_KEY: actor_handle, + BIND_INDEX_KEY: actor_handle._ray_dag_bind_index, + COLLECTIVE_OPERATION_KEY: collective_op, + }, + ) + actor_handle._ray_dag_bind_index += 1 + collective_output_nodes.append(collective_output_node) + + return collective_output_nodes + + def __call__( + self, + tensor, + group_name: str = "default", + op: RayReduceOp = RayReduceOp.SUM, + ): + from ray.util.collective.collective import allreduce + + return allreduce(tensor, group_name, op) + + +allreduce = AllReduceWrapper() diff --git a/python/ray/experimental/util/__init__.py b/python/ray/experimental/util/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/experimental/util/types.py b/python/ray/experimental/util/types.py new file mode 100644 index 000000000000..abbb5968dc06 --- /dev/null +++ b/python/ray/experimental/util/types.py @@ -0,0 +1,19 @@ +from enum import Enum + +from ray.util.annotations import PublicAPI + + +class _CollectiveOp(Enum): + pass + + +@PublicAPI +class ReduceOp(_CollectiveOp): + SUM = 0 + PRODUCT = 1 + MAX = 2 + MIN = 3 + AVG = 4 + + def __str__(self): + return f"{self.name.lower()}"
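For context, a minimal usage sketch of the new `collective.allreduce.bind` API inside a compiled DAG, modeled on `test_torch_tensor_nccl_all_reduce` above. This is not part of the diff: the `Worker` actor and its `compute`/`read` methods are hypothetical stand-ins for `TorchTensorWorker`, and the sketch assumes a cluster with at least two GPU actors.

import ray
import ray.experimental.collective as collective
import torch
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.util.types import ReduceOp


@ray.remote(num_cpus=0, num_gpus=1)
class Worker:
    # Hypothetical actor for illustration; mirrors the TorchTensorWorker pattern.
    def compute(self, value: int) -> "torch.Tensor":
        # Produce a tensor on this actor's GPU; each actor contributes one input.
        return torch.ones((10,), dtype=torch.float16, device="cuda") * value

    def read(self, tensor: "torch.Tensor") -> float:
        # Return a CPU scalar so the driver does not need a GPU to read the result.
        return tensor.sum().item()


workers = [Worker.remote() for _ in range(2)]

with InputNode() as inp:
    # One compute node per actor; allreduce.bind returns one output node per input.
    tensors = [w.compute.bind(inp) for w in workers]
    reduced = collective.allreduce.bind(tensors, ReduceOp.SUM)
    # Each reduced tensor is read back on the actor that owns it.
    dag = MultiOutputNode([w.read.bind(t) for w, t in zip(workers, reduced)])

compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(2)
# Both actors hold ones * (2 + 2) after the all-reduce, so each sum is 40.0.
print(ray.get(ref))  # [40.0, 40.0]
compiled_dag.teardown()

As in the tests, the output of `allreduce.bind` is a list with one `CollectiveOutputNode` per input node, and each output should be consumed by a downstream task on the same actor that produced the corresponding input.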