From c461da614cf50457f5e6da374b2940275429f792 Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Tue, 29 Oct 2024 22:51:10 -0700
Subject: [PATCH] process_group: add inprocess ProcessGroupNCCL

---
 torchft/process_group.py      | 40 +++++++++++++++++++++++++++++++++++-----
 torchft/process_group_test.py | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 67 insertions(+), 5 deletions(-)

diff --git a/torchft/process_group.py b/torchft/process_group.py
index 7d4c330..494c687 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -85,9 +85,10 @@ def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")
 
 
-class ProcessGroupGloo(ProcessGroup):
+class ProcessGroupWrapper(ProcessGroup):
+    PG_CLASS: Type[BaseProcessGroup]
     """
-    This is a wrapper around ProcessGroupGloo with a reconfiguration argument.
+    This is a wrapper around any ProcessGroup with a reconfiguration method.
     """
 
     def __init__(self) -> None:
@@ -95,11 +96,15 @@ def __init__(self) -> None:
         self._pg = None
 
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        if self._pg is not None:
+            if hasattr(self._pg, "abort"):
+                self._pg.abort()
+            self._pg = None
+
         store = create_store(store_addr)
 
-        # TODO: set lower timeout
-        # pyre-fixme[16]: no attribute ProcessGroupGloo
-        self._pg = BaseProcessGroupGloo(store, rank, world_size)
+        # TODO: set global timeout
+        self._pg = self.PG_CLASS(store, rank, world_size)
 
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return self._pg.allreduce(tensors, opts)
@@ -118,10 +123,35 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._pg.size()
 
+
+class ProcessGroupGloo(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupGloo.
+    """
+
+    PG_CLASS = BaseProcessGroupGloo
+
     def getBackendName(self) -> str:
         return "torchft-gloo"
 
 
+class ProcessGroupNCCL(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupNCCL.
+
+    WARNING: this may result in deadlocks due to NCCL error handling. This is
+    provided for completeness but your mileage may vary.
+
+    TODO: verify shutdown correctness with latest NCCL. This currently calls
+    abort when reconfiguring; we need to ensure this is safe.
+    """
+
+    PG_CLASS = BaseProcessGroupNCCL
+
+    def getBackendName(self) -> str:
+        return "torchft-nccl"
+
+
 class DummyWork(dist._Work):
     def __init__(self, result):
         super().__init__()
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 3b3aafd..f741b7f 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -15,6 +15,7 @@
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
     ProcessGroupGloo,
+    ProcessGroupNCCL,
     ProcessGroupDummy,
     ProcessGroup,
 )
@@ -41,6 +42,37 @@ def test_gloo(self) -> None:
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
         m(torch.rand(2, 3))
 
+    @skipUnless(torch.cuda.is_available(), "needs CUDA")
+    def test_nccl(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        device = "cuda"
+
+        store_addr = f"localhost:{store.port}/prefix"
+        pg = ProcessGroupNCCL()
+        pg.configure(store_addr, 0, 1)
+
+        self.assertEqual(pg.size(), 1)
+
+        at = torch.tensor([2], device=device)
+        a_work = pg.allreduce([at], ReduceOp.SUM)
+        a_work.wait()
+
+        m = nn.Linear(3, 4).to(device)
+        m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
+        m(torch.rand(2, 3, device=device))
+
+        # reconfigure
+        store_addr = f"localhost:{store.port}/prefix2"
+        pg.configure(store_addr, 0, 1)
+
+        at = torch.tensor([2], device=device)
+        a_work = pg.allreduce([at], ReduceOp.SUM)
+        a_work.wait()
+
+        torch.cuda.synchronize()
+
     def test_baby_gloo(self) -> None:
         store = TCPStore(
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
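Usage note: below is a minimal single-rank sketch of the reconfiguration flow this patch adds, mirroring the new test_nccl test. It is illustrative only and assumes a checkout where torchft.process_group is importable; it uses the reconfigurable Gloo wrapper so it runs without a GPU, but ProcessGroupNCCL is driven the same way with CUDA tensors.

import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupGloo

# One lightweight store shared across reconfigurations; each configure()
# call uses a fresh prefix so no state is reused between incarnations.
store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

pg = ProcessGroupGloo()
pg.configure(f"localhost:{store.port}/prefix", rank=0, world_size=1)

t = torch.tensor([2])
pg.allreduce([t], ReduceOp.SUM).wait()

# Simulate a membership change: configure() tears down the previous backend
# (calling abort() when the backend provides it) and builds a new one.
pg.configure(f"localhost:{store.port}/prefix2", rank=0, world_size=1)
pg.allreduce([t], ReduceOp.SUM).wait()

The point of the wrapper design is that callers keep a single stable ProcessGroup handle, such as the one handed to DistributedDataParallel in the test, while configure() swaps the underlying backend beneath it.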