Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a NCCL sub-communicator using ncclCommSplit #2495

Open
wants to merge 8 commits into
base: branch-24.12
Choose a base branch
from
34 changes: 30 additions & 4 deletions cpp/include/raft/comms/detail/mpi_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,31 @@ class mpi_comms : public comms_iface {
RAFT_MPI_TRY(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, mpi_comm_));

// initializing NCCL
RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_));
ncclConfig_t nccl_config = NCCL_CONFIG_INITIALIZER;
nccl_config.splitShare = 1;
RAFT_NCCL_TRY(ncclCommInitRankConfig(&nccl_comm_, size_, id, rank_, &nccl_config));

initialize();
}

mpi_comms(MPI_Comm mpi_comm,
bool owns_mpi_comm,
ncclComm_t nccl_comm,
rmm::cuda_stream_view stream)
: owns_mpi_comm_(owns_mpi_comm),
mpi_comm_(mpi_comm),
nccl_comm_(nccl_comm),
size_(0),
rank_(1),
status_(stream),
next_request_id_(0),
stream_(stream)
{
int mpi_is_initialized = 0;
RAFT_MPI_TRY(MPI_Initialized(&mpi_is_initialized));
RAFT_EXPECTS(mpi_is_initialized, "ERROR: MPI is not initialized!");
RAFT_MPI_TRY(MPI_Comm_size(mpi_comm_, &size_));
RAFT_MPI_TRY(MPI_Comm_rank(mpi_comm_, &rank_));

initialize();
}
Expand All @@ -150,9 +174,11 @@ class mpi_comms : public comms_iface {

std::unique_ptr<comms_iface> comm_split(int color, int key) const
{
MPI_Comm new_comm;
RAFT_MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_comm));
return std::unique_ptr<comms_iface>(new mpi_comms(new_comm, true, stream_));
MPI_Comm new_mpi_comm;
RAFT_MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_mpi_comm));
ncclComm_t new_nccl_comm{};
RAFT_NCCL_TRY(ncclCommSplit(nccl_comm_, color, key, &new_nccl_comm, nullptr));
return std::unique_ptr<comms_iface>(new mpi_comms(new_mpi_comm, true, new_nccl_comm, stream_));
}

void barrier() const
Expand Down
55 changes: 6 additions & 49 deletions cpp/include/raft/comms/detail/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,55 +140,12 @@ class std_comms : public comms_iface {

std::unique_ptr<comms_iface> comm_split(int color, int key) const
{
rmm::device_uvector<int> d_colors(get_size(), stream_);
rmm::device_uvector<int> d_keys(get_size(), stream_);

update_device(d_colors.data() + get_rank(), &color, 1, stream_);
update_device(d_keys.data() + get_rank(), &key, 1, stream_);

allgather(d_colors.data() + get_rank(), d_colors.data(), 1, datatype_t::INT32, stream_);
allgather(d_keys.data() + get_rank(), d_keys.data(), 1, datatype_t::INT32, stream_);
this->sync_stream(stream_);

std::vector<int> h_colors(get_size());
std::vector<int> h_keys(get_size());

update_host(h_colors.data(), d_colors.data(), get_size(), stream_);
update_host(h_keys.data(), d_keys.data(), get_size(), stream_);

this->sync_stream(stream_);

ncclComm_t nccl_comm;

// Create a structure to allgather...
ncclUniqueId id{};
rmm::device_uvector<ncclUniqueId> d_nccl_ids(get_size(), stream_);

if (key == 0) { RAFT_NCCL_TRY(ncclGetUniqueId(&id)); }

update_device(d_nccl_ids.data() + get_rank(), &id, 1, stream_);

allgather(d_nccl_ids.data() + get_rank(),
d_nccl_ids.data(),
sizeof(ncclUniqueId),
datatype_t::UINT8,
stream_);

auto offset =
std::distance(thrust::make_zip_iterator(h_colors.begin(), h_keys.begin()),
std::find_if(thrust::make_zip_iterator(h_colors.begin(), h_keys.begin()),
thrust::make_zip_iterator(h_colors.end(), h_keys.end()),
[color](auto tuple) { return thrust::get<0>(tuple) == color; }));

auto subcomm_size = std::count(h_colors.begin(), h_colors.end(), color);

update_host(&id, d_nccl_ids.data() + offset, 1, stream_);

this->sync_stream(stream_);

RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_size, id, key));

return std::unique_ptr<comms_iface>(new std_comms(nccl_comm, subcomm_size, key, stream_, true));
ncclComm_t new_nccl_comm{};
RAFT_NCCL_TRY(ncclCommSplit(nccl_comm_, color, key, &new_nccl_comm, nullptr));
int new_nccl_comm_size{};
RAFT_NCCL_TRY(ncclCommCount(new_nccl_comm, &new_nccl_comm_size));
return std::unique_ptr<comms_iface>(
new std_comms(new_nccl_comm, new_nccl_comm_size, key, stream_, true));
}

void barrier() const
Expand Down
Loading