diff --git a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py
index 1797068e7e2d..155ff3d875e0 100644
--- a/python/ray/dag/tests/experimental/test_torch_tensor_dag.py
+++ b/python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -23,6 +23,9 @@
 from ray.experimental.channel.torch_tensor_type import TorchTensorType
 from ray.tests.conftest import *  # noqa
+
+from ray.air._internal.device_manager.npu import NPU_TORCH_PACKAGE_AVAILABLE
+
 from ray.experimental.util.types import ReduceOp
 
 logger = logging.getLogger(__name__)
@@ -1227,6 +1230,87 @@ def test_torch_tensor_nccl_all_reduce_scheduling(ray_start_regular):
     assert result[2] == (value, shape, dtype)
 
 
+NPU_DEVICES = "0,1,2,3,4,5,6,7"
+
+
+@ray.remote(resources={"NPU": 1})
+class TorchTensorWorkerNPU:
+    # NOTE(zhilong): To run the NPU test, we need to change
+    # "from ray.experimental.channel.nccl_group import _NcclGroup"
+    # to "from ray.experimental.channel.hccl_group import _HcclGroup"
+    # in "python/ray/experimental/channel/torch_tensor_nccl_channel.py"
+    # and also disable the all-GPU device check.
+
+    # TODO(zhilong): Refactor the aDAG channel so that it supports
+    # different XPUs.
+
+    def __init__(self, rank):
+        import torch  # noqa: F401
+
+        os.environ["ASCEND_RT_VISIBLE_DEVICES"] = NPU_DEVICES
+        import torch_npu
+
+        self.rank = rank
+        torch_npu.npu.set_device(rank)
+
+    def send(self, shape, dtype, value: int):
+        import torch
+
+        os.environ["ASCEND_RT_VISIBLE_DEVICES"] = NPU_DEVICES
+        import torch_npu
+
+        # torch_npu may need to be imported a second time here to keep the NPU
+        # context; otherwise the context is lost. Unlike NCCL with cupy, the
+        # NPU channel relies on torch, so the torch context must be preserved.
+        # Create and return a tensor filled with `value` on the current NPU.
+        torch_npu.npu.set_device(self.rank)
+        tensor = torch.ones(shape, dtype=dtype) * value
+        return tensor.to(f"npu:{self.rank}")
+
+    def recv(self, tensor):
+        # Move the received tensor to CPU and return its value, shape, and dtype.
+        tensor = tensor.cpu()
+        return (tensor[0].item(), tensor.shape, tensor.dtype)
+
+
+@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
+def test_torch_tensor_npu_communication(ray_start_regular):
+    if not NPU_TORCH_PACKAGE_AVAILABLE:
+        pytest.skip("This test requires NPUs.")
+
+    assert (
+        sum(node["Resources"].get("NPU", 0) for node in ray.nodes()) > 1
+    ), "This test requires at least 2 NPUs"
+
+    # Initialize the actor class with NPU support.
+    actor_cls = TorchTensorWorkerNPU
+    sender = actor_cls.remote(0)
+    receiver = actor_cls.remote(1)
+
+    shape = (10,)
+    dtype = torch.float16
+
+    # Define the DAG with NPU actors.
+    with InputNode() as inp:
+        dag = sender.send.bind(shape, dtype, inp)
+        # Can use HCCL after PR 47845 is merged.
+        dag = dag.with_type_hint(
+            TorchTensorType(shape, dtype, transport="hccl", _direct_return=True)
+        )
+        dag = receiver.recv.bind(dag)
+
+    compiled_dag = dag.experimental_compile()
+
+    # Test tensor sending and receiving on NPUs.
+    for i in range(3):
+        ref = compiled_dag.execute(i)
+        result = ray.get(ref)
+        assert result == (i, shape, dtype)
+
+    compiled_dag.teardown()
+
+
 if __name__ == "__main__":
     if os.environ.get("PARALLEL_CI"):
         sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
diff --git a/python/ray/experimental/channel/hccl_group.py b/python/ray/experimental/channel/hccl_group.py
new file mode 100644
index 000000000000..7da79f5b8b82
--- /dev/null
+++ b/python/ray/experimental/channel/hccl_group.py
@@ -0,0 +1,202 @@
+import logging
+import os
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch_npu  # torch_npu provides NPU support and the HCCL backend.
+
+import ray
+from ray.exceptions import RayChannelError
+from ray.experimental.channel.gpu_communicator import (
+    GPUCommunicator,
+    TorchTensorAllocator,
+)
+from ray.experimental.util.types import ReduceOp
+
+# Set the ASCEND_RT_VISIBLE_DEVICES environment variable so that all NPUs are
+# visible. This enables NPU-to-NPU communication across devices.
+# Explanation: since each worker normally only sees the GPU/NPU assigned to
+# it, the worker needs to see all NPUs to set up the communication channel.
+os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
+
+logger = logging.getLogger(__name__)
+
+
+class _HcclGroup(GPUCommunicator):
+    """
+    Represents an actor's HCCL communicator using NPUs.
+
+    This is the default HCCL communicator to be used in aDAG if a
+    custom communicator is not provided.
+
+    This class is not thread-safe.
+    """
+
+    def __init__(
+        self,
+        world_size: int,
+        comm_id: int,
+        rank: int,
+        actor_handles: list,
+        cuda_stream: Optional[int],
+    ):
+        # TODO(zhilong): Rename cuda_stream to a more general name like "stream".
+        """
+        Initialize an HCCL communicator that can be used to communicate p2p with
+        other NPU actors.
+
+        This method blocks until the same call has been made on all other
+        actors in the group, with the same arguments for world_size and comm_id.
+
+        Args:
+            world_size: The number of participating actors/devices.
+            comm_id: A unique communicator ID.
+            rank: The rank of this actor. If None, then the caller is not a
+                participant of the HCCL group.
+            actor_handles: A list of actor handles, in rank order.
+            cuda_stream: Not used here; kept to match the arguments of nccl_group.
+        """
+        self._world_size = world_size
+        self._comm_id = comm_id
+        self._rank = rank
+        self._actor_handles = actor_handles
+        self._closed = False
+        # Initialize distributed HCCL communication if a rank is provided.
+        if rank is not None:
+            self._init_dist_hccl(rank, world_size)
+
+    def _init_dist_hccl(self, rank, world_size):
+        """
+        Initialize the HCCL communication group on NPUs.
+
+        Args:
+            rank: The rank of the current process.
+            world_size: The total number of processes participating
+                in the communication.
+        """
+        # Set environment variables if not already set.
+        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
+        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
+        os.environ["HCCL_WHITELIST_DISABLE"] = os.environ.get(
+            "HCCL_WHITELIST_DISABLE", "1"
+        )
+
+        torch_npu.npu.set_device(rank)  # Set the NPU device according to the rank.
+        self.ctx = dist.init_process_group(
+            backend="hccl", world_size=world_size, rank=rank
+        )
+
+    def initialize(self, rank: int) -> None:
+        pass  # No additional initialization is needed for the HCCL group.
+
+    def get_actor_handles(self) -> list:
+        """
+        Return the list of actor handles.
+
+        Returns:
+            list: Actor handles in rank order.
+        """
+        return self._actor_handles
+
+    def get_rank(self, actor: "ray.actor.ActorHandle") -> int:
+        """
+        Return the given actor's rank in the HCCL communicator.
+
+        Args:
+            actor: The actor handle to look up.
+
+        Returns:
+            int: The rank of the actor.
+        """
+        actor_ids = [a._ray_actor_id for a in self._actor_handles]
+        try:
+            rank = actor_ids.index(actor._ray_actor_id)
+        except ValueError:
+            raise ValueError("Actor is not in the HCCL group.")
+        return rank
+
+    def get_self_rank(self) -> int:
+        """
+        Return this actor's rank.
+
+        Returns:
+            int: The rank of this actor in the HCCL group.
+        """
+        return self._rank
+
+    def get_world_size(self) -> int:
+        """
+        Return the number of ranks in the HCCL communicator.
+
+        Returns:
+            int: The world size of the HCCL group.
+        """
+        return self._world_size
+
+    def send(self, tensor: "torch.Tensor", peer_rank: int) -> None:
+        """
+        Send a tensor to a peer using HCCL.
+
+        Args:
+            tensor: The tensor to be sent.
+            peer_rank: The rank of the peer to send the tensor to.
+        """
+        if self._closed:
+            raise RuntimeError("HCCL group has been destroyed.")
+        logger.info(f"Starting send to rank {peer_rank} from rank {self._rank}.")
+        dist.send(tensor, dst=peer_rank)
+
+    def recv(
+        self,
+        shape: tuple,
+        dtype: "torch.dtype",
+        peer_rank: int,
+        allocator: Optional[TorchTensorAllocator],
+    ) -> "torch.Tensor":
+        """
+        Receive a tensor from a peer using HCCL.
+
+        Args:
+            shape: The shape of the tensor to receive.
+            dtype: The data type of the tensor.
+            peer_rank: The rank of the peer to receive the tensor from.
+            allocator: Optional allocator to allocate memory for the tensor.
+
+        Returns:
+            torch.Tensor: The received tensor.
+        """
+        if self._closed:
+            raise RuntimeError("HCCL group has been destroyed.")
+        torch_npu.npu.set_device(f"npu:{self._rank}")
+        tensor = torch.zeros(*shape, dtype=dtype).to(f"npu:{self._rank}")
+        dist.recv(tensor, src=peer_rank)
+        if self._closed:
+            raise RayChannelError("HCCL group has been destroyed.")
+        return tensor
+
+    def recv_stream(self):
+        pass
+
+    def send_stream(self):
+        pass
+
+    def allreduce(
+        self,
+        send_buf: "torch.Tensor",
+        recv_buf: "torch.Tensor",
+        op: ReduceOp,
+    ) -> None:
+        pass
+
+    def destroy(self) -> None:
+        """
+        Destroy the HCCL group and clean up resources.
+        """
+        self._closed = True
+        dist.destroy_process_group()
+        if self._rank is not None:
+            logger.info(
+                "Destroying HCCL group on actor: "
+                f"{ray.get_runtime_context().current_actor}"
+            )
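
For reference, the point-to-point pattern that _HcclGroup wraps can be reproduced outside of Ray with torch.distributed's HCCL backend. The sketch below is not part of the patch: it assumes torch_npu is installed and at least two NPUs are visible, reuses the same rendezvous defaults (MASTER_ADDR, MASTER_PORT, HCCL_WHITELIST_DISABLE) that _init_dist_hccl sets above, and the worker function name is only illustrative.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch_npu  # Registers the "npu" device and the HCCL backend for torch.


def _p2p_worker(rank: int, world_size: int) -> None:
    # Same rendezvous defaults as _HcclGroup._init_dist_hccl above.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("HCCL_WHITELIST_DISABLE", "1")

    torch_npu.npu.set_device(rank)
    dist.init_process_group(backend="hccl", world_size=world_size, rank=rank)

    if rank == 0:
        # Sender: build a tensor on the local NPU and send it to rank 1.
        tensor = (torch.ones(10, dtype=torch.float16) * 42).to("npu:0")
        dist.send(tensor, dst=1)
    else:
        # Receiver: allocate a buffer of the same shape/dtype and receive into it.
        tensor = torch.zeros(10, dtype=torch.float16).to("npu:1")
        dist.recv(tensor, src=0)
        assert tensor[0].item() == 42

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_p2p_worker, args=(2,), nprocs=2)
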
diff --git a/python/ray/experimental/channel/torch_tensor_nccl_channel.py b/python/ray/experimental/channel/torch_tensor_nccl_channel.py
index 4b8baedeebd0..7ab4d711fa07 100644
--- a/python/ray/experimental/channel/torch_tensor_nccl_channel.py
+++ b/python/ray/experimental/channel/torch_tensor_nccl_channel.py
@@ -1,5 +1,6 @@
 import io
 import logging
+import os
 import uuid
 from dataclasses import dataclass
 from types import ModuleType
@@ -10,7 +11,6 @@
 from ray.experimental.channel import ChannelContext
 from ray.experimental.channel.common import ChannelInterface
 from ray.experimental.channel.gpu_communicator import GPUCommunicator
-from ray.experimental.channel.nccl_group import _NcclGroup
 from ray.experimental.channel.shared_memory_channel import SharedMemoryType
 from ray.experimental.channel.torch_tensor_type import TorchTensorType
 from ray.util.annotations import DeveloperAPI
@@ -25,6 +25,20 @@
 # into the program using Ray. Ray provides a default configuration at
 # entry/init points.
 logger = logging.getLogger(__name__)
+USE_GPU = True
+USE_NPU = False
+if os.getenv("ASCEND_RT_VISIBLE_DEVICES"):
+    try:
+        from ray.experimental.channel.hccl_group import _HcclGroup as _NcclGroup
+
+        USE_GPU = False
+        USE_NPU = True
+    except Exception:
+        logger.warning("Failed to import hccl_group; falling back to nccl_group.")
+        from ray.experimental.channel.nccl_group import _NcclGroup
+
+else:
+    from ray.experimental.channel.nccl_group import _NcclGroup
 
 
 @dataclass
@@ -187,12 +201,18 @@ def write(self, value: Any, timeout: Optional[float] = None) -> None:
                     "return a CUDA torch.Tensor, instead found value "
                     f"`{value}`. DAG will shut down."
                 )
-            elif not value.is_cuda:
+            if USE_GPU and (not value.is_cuda):
                 raise ValueError(
                     "Task annotated with _direct_return=True must "
                     "return a CUDA torch.Tensor, instead found CPU tensor. "
                     "DAG will shut down."
                 )
+            elif USE_NPU and (not value.is_npu):
+                raise ValueError(
+                    "Task annotated with _direct_return=True must "
+                    "return an NPU torch.Tensor, instead found CPU tensor. "
+                    "DAG will shut down."
+                )
             self._gpu_data_channel.write([value], timeout=timeout)
         else:
             self._send_cpu_and_gpu_data(value, timeout)
@@ -361,13 +381,13 @@ def ensure_registered_as_writer(self):
         assert self._nccl_group is not None, "Actor is not part of a NCCL group"
         assert self._writer_registered
         ctx = ChannelContext.get_current()
-        assert ctx.torch_device.type == "cuda"
+        assert ctx.torch_device.type in ["cuda", "npu"]
 
     def ensure_registered_as_reader(self) -> bool:
         assert self._nccl_group is not None, "Actor is not part of a NCCL group"
         assert self._reader_registered
         ctx = ChannelContext.get_current()
-        assert ctx.torch_device.type == "cuda"
+        assert ctx.torch_device.type in ["cuda", "npu"]
 
     def __reduce__(self):
         return (
@@ -553,15 +573,15 @@ def _do_init_nccl_group(
 ):
     import torch
 
-    assert (
-        ray.get_gpu_ids()
-    ), "Actors participating in NCCL group must have at least one GPU assigned"
+    assert bool(ray.get_gpu_ids()) or bool(
+        "NPU" in ray.cluster_resources()
+    ), "Actors participating in a communicator group must have at least one XPU assigned"
 
     ctx = ChannelContext.get_current()
     if custom_nccl_group is not None:
         custom_nccl_group.initialize(rank)
         ctx.nccl_groups[group_id] = custom_nccl_group
-    else:
+    elif USE_GPU:
         ctx.nccl_groups[group_id] = _NcclGroup(
             world_size,
             comm_id,
@@ -570,6 +590,14 @@ def _do_init_nccl_group(
             torch.cuda.current_stream().cuda_stream,
             use_communication_streams,
         )
+    else:
+        ctx.nccl_groups[group_id] = _NcclGroup(
+            world_size,
+            comm_id,
+            rank,
+            actor_handles,
+            None,
+        )
 
 
 def _do_destroy_nccl_group(self, group_id):
@@ -584,13 +612,18 @@ def _do_destroy_nccl_group(self, group_id):
 
 
 def _do_check_has_gpu(self) -> bool:
-    return bool(ray.get_gpu_ids())
+    # Check for a GPU or an NPU.
+    return bool(ray.get_gpu_ids()) or bool("NPU" in ray.cluster_resources())
 
 
 def _do_get_unique_nccl_id(self) -> bool:
-    from cupy.cuda import nccl
+    if "NPU" in ray.cluster_resources():
+        # HCCL does not expose get_unique_id, so generate an ID with uuid.
+        return uuid.uuid4()
+    else:
+        from cupy.cuda import nccl
 
-    return nccl.get_unique_id()
+        return nccl.get_unique_id()
 
 
 def _get_ranks(
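
With the import switch above, the HCCL path is selected purely by the presence of ASCEND_RT_VISIBLE_DEVICES when the channel module is imported. The standalone sketch below restates that selection rule as a plain function for clarity; it is an approximation of the module-level logic rather than the actual code path, and _select_backend is a hypothetical name.

import os


def _select_backend() -> str:
    # Approximation of the module-level switch: prefer HCCL only when NPU
    # devices are exposed via ASCEND_RT_VISIBLE_DEVICES and the NPU stack
    # (torch_npu, and hence hccl_group) imports cleanly.
    if os.getenv("ASCEND_RT_VISIBLE_DEVICES"):
        try:
            import torch_npu  # noqa: F401

            return "hccl"
        except Exception:
            return "nccl"
    return "nccl"


print(_select_backend())  # "nccl" unless NPU devices are exposed via the env var.
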
diff --git a/python/ray/experimental/channel/torch_tensor_type.py b/python/ray/experimental/channel/torch_tensor_type.py
index 8615f18b7d65..cadb15f8b38e 100644
--- a/python/ray/experimental/channel/torch_tensor_type.py
+++ b/python/ray/experimental/channel/torch_tensor_type.py
@@ -17,6 +17,8 @@ class TorchTensorType(ChannelOutputType):
 
     AUTO = "auto"
     NCCL = "nccl"
+    HCCL = "hccl"
+    COMMUNICATOR_TYPES = [NCCL, HCCL]
 
     def __init__(
         self,
@@ -69,9 +71,10 @@ def __init__(
             self._custom_nccl_group = transport
             transport = self.NCCL
 
-        if transport not in [self.AUTO, self.NCCL]:
+        if transport not in [self.AUTO, self.NCCL, self.HCCL]:
             raise ValueError(
-                "`transport` must be TorchTensorType.AUTO or TorchTensorType.NCCL"
+                "`transport` must be TorchTensorType.AUTO, "
+                "TorchTensorType.NCCL, or TorchTensorType.HCCL"
             )
         self.transport = transport
 
@@ -162,7 +165,7 @@ def create_channel(
         return typ.create_channel(writer, reader_and_node_list, read_by_adag_driver)
 
     def requires_nccl(self) -> bool:
-        return self.transport == self.NCCL
+        return self.transport in self.COMMUNICATOR_TYPES
 
     def get_custom_nccl_group(self) -> Optional[GPUCommunicator]:
         """
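
With the new HCCL constant, a transport="hccl" type hint follows the same collective-communication path as NCCL at scheduling time, because requires_nccl() now checks COMMUNICATOR_TYPES. A small sketch of the intended behavior, assuming the constructor's shape and dtype arguments keep their defaults:

from ray.experimental.channel.torch_tensor_type import TorchTensorType

# Both communicator transports report that they need a collective group.
hccl_hint = TorchTensorType(transport=TorchTensorType.HCCL)
assert hccl_hint.requires_nccl()

# The AUTO transport does not require a communicator group.
auto_hint = TorchTensorType(transport=TorchTensorType.AUTO)
assert not auto_hint.requires_nccl()
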