Skip to content

Commit

Permalink
torch.compile: register all-reduce operations as custom ops (#1050)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlpinDale authored Dec 27, 2024
1 parent 4593a3b commit 239a8ca
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 42 deletions.
6 changes: 0 additions & 6 deletions aphrodite/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,12 +901,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
offsets, rank, full_nvlink)


def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int,
full_nvlink: bool) -> bool:
return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size,
full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out)

Expand Down
21 changes: 19 additions & 2 deletions aphrodite/distributed/device_communicators/custom_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def _can_p2p(rank: int, world_size: int) -> bool:
return True


def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (inp.storage().nbytes() -
inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size())


class CustomAllreduce:

_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
Expand Down Expand Up @@ -230,8 +236,19 @@ def register_graph_buffers(self):
ops.register_graph_buffers(self._ptr, handles, offsets)

def should_custom_ar(self, inp: torch.Tensor):
return ops.should_custom_ar(inp, self.max_size, self.world_size,
self.full_nvlink)
if self.disabled:
return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False

# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
Expand Down
101 changes: 87 additions & 14 deletions aphrodite/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
import contextlib
import pickle
import sys
import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
Expand Down Expand Up @@ -71,6 +72,48 @@ def _split_tensor_dict(
return metadata_list, tensor_list


_group_name_counter: Dict[str, int] = {}
def _get_unique_name(name: str) -> str:
"""Get a unique name for the group.
Example:
_get_unique_name("tp") -> "tp:0"
_get_unique_name("tp") -> "tp:1"
"""
if name not in _group_name_counter:
_group_name_counter[name] = 0
newname = f"{name}:{_group_name_counter[name]}"
_group_name_counter[name] += 1
return newname
_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {}

def _register_group(group: "GroupCoordinator") -> None:
# looks like Python 3.8 does not understand `ReferenceType`
_groups[group.unique_name] = weakref.ref(group) # type: ignore

@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
group._all_reduce(tensor)

@inplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> None:
return
@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
assert group_name in _groups, f"Group {group_name} is not found."
group = _groups[group_name]()
if group is None:
raise ValueError(f"Group {group_name} is destroyed.")
return group._all_reduce(tensor)

@outplace_all_reduce.register_fake
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)


class GroupCoordinator:
"""
PyTorch ProcessGroup wrapper for a group of processes.
Expand Down Expand Up @@ -113,7 +156,11 @@ def __init__(
use_custom_allreduce: bool,
use_tpu_communicator: bool,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
):
group_name = group_name or "anonymous"
self.unique_name = _get_unique_name(group_name)
_register_group(self)

self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
Expand Down Expand Up @@ -151,28 +198,24 @@ def __init__(
from aphrodite.distributed.device_communicators.pynccl import (
PyNcclCommunicator)

self.pynccl_comm: Optional[PyNcclCommunicator]
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
group=self.cpu_group,
device=self.device,
)
else:
self.pynccl_comm = None

self.ca_comm: Optional[CustomAllreduce]
self.ca_comm: Optional[CustomAllreduce] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
)
else:
self.ca_comm = None

from aphrodite.distributed.device_communicators.tpu_communicator import ( # noqa: E501
TpuCommunicator)
self.tpu_communicator: Optional[TpuCommunicator]
self.tpu_communicator: Optional[TpuCommunicator] = None
if use_tpu_communicator and self.world_size > 1:
self.tpu_communicator = TpuCommunicator(group=self.cpu_group)

Expand Down Expand Up @@ -266,16 +309,41 @@ def graph_capture(

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
"""
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_
if self.tpu_communicator is not None and \
not self.tpu_communicator.disabled:
# TPU handles Dynamo with its own logic.
return self._all_reduce(input_)
if self.ca_comm is not None and self.ca_comm.should_custom_ar(input_):
return torch.ops.aphrodite.outplace_all_reduce(
input_, group_name=self.unique_name)
else:
torch.ops.aphrodite.inplace_all_reduce(input_,
group_name=self.unique_name)
return input_

def _all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""
The actual all-reduce implementation.
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
value as the output.
"""
ca_comm = self.ca_comm

# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_

# For TPUs, use TPU communicator.
tpu_comm = self.tpu_communicator
if tpu_comm is not None and not tpu_comm.disabled:
Expand Down Expand Up @@ -760,6 +828,7 @@ def init_world_group(ranks: List[int], local_rank: int,
use_pynccl=False,
use_custom_allreduce=False,
use_tpu_communicator=False,
group_name="world",
)


Expand All @@ -769,6 +838,7 @@ def init_model_parallel_group(
backend: str,
use_custom_allreduce: Optional[bool] = None,
use_message_queue_broadcaster: bool = False,
group_name: Optional[str] = None,
) -> GroupCoordinator:
if use_custom_allreduce is None:
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
Expand All @@ -780,6 +850,7 @@ def init_model_parallel_group(
use_custom_allreduce=use_custom_allreduce,
use_tpu_communicator=True,
use_message_queue_broadcaster=use_message_queue_broadcaster,
group_name=group_name,
)


Expand Down Expand Up @@ -935,7 +1006,8 @@ def initialize_model_parallel(
_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True)
use_message_queue_broadcaster=True,
group_name="tp")

# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups: int = (world_size //
Expand All @@ -951,7 +1023,8 @@ def initialize_model_parallel(
_PP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_custom_allreduce=False)
use_custom_allreduce=False,
group_name="pp")


def ensure_model_parallel_initialized(
Expand Down
12 changes: 0 additions & 12 deletions kernels/all_reduce/custom_all_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,6 @@ bool _is_weak_contiguous(torch::Tensor& t) {
t.numel() * t.element_size());
}

bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16
if (inp_size % 16 != 0) return false;
if (!_is_weak_contiguous(inp)) return false;
if (world_size == 2 || full_nvlink) return inp_size <= max_size;
// for 4 or more non NVLink-capable GPUs, custom allreduce provides little
// performance improvement over NCCL.
return false;
}

void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
cudaStream_t stream) {
auto fa = reinterpret_cast<aphrodite::CustomAllreduce*>(_fa);
Expand Down
2 changes: 0 additions & 2 deletions kernels/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size,
bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out);
Expand Down
4 changes: 0 additions & 4 deletions kernels/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,10 +512,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
"str[] handles, int[] offsets, int rank, "
"bool full_nvlink) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
custom_ar.def(
"should_custom_ar(Tensor inp, int max_size, int world_size, "
"bool full_nvlink) -> bool");
custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);

custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
Expand Down
Empty file added tests/compile/__init__.py
Empty file.
13 changes: 11 additions & 2 deletions tests/compile/test_full_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,18 @@

import pytest

from aphrodite.common.utils import cuda_device_count_stateless

from ..utils import fork_new_process_for_each_test


@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_full_graph(model):
@pytest.mark.parametrize("tp_size", [1, 2])
@fork_new_process_for_each_test
def test_full_graph(model, tp_size):
# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")
# make sure these models can be captured in full graph mode
os.environ["APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"

Expand All @@ -16,7 +25,7 @@ def test_full_graph(model):
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0)
llm = LLM(model=model, enforce_eager=True)
llm = LLM(model=model, enforce_eager=True, tensor_parallel_size=tp_size)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
Expand Down

0 comments on commit 239a8ca

Please sign in to comment.