process_group: add inprocess ProcessGroupNCCL
d4l3k committed Oct 30, 2024
1 parent f07c80c commit 4bf29d5
Showing 2 changed files with 72 additions and 6 deletions.
torchft/process_group.py (40 additions, 6 deletions)
@@ -85,21 +85,30 @@ def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")
 
 
-class ProcessGroupGloo(ProcessGroup):
+class ProcessGroupWrapper(ProcessGroup):
+    PG_CLASS: Type[BaseProcessGroup]
     """
-    This is a wrapper around ProcessGroupGloo with a reconfiguration argument.
+    This is a wrapper around any ProcessGroup with a reconfiguration method.
     """
 
-    def __init__(self) -> None:
+    def __init__(self, timeout: float = 60.0) -> None:
+        """
+        Args:
+            timeout: the timeout to use for the process group
+        """
         super().__init__(0, 1)
         self._pg = None
 
     def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        if self._pg is not None:
+            if hasattr(self._pg, "abort"):
+                self._pg.abort()
+            self._pg = None
+
         store = create_store(store_addr)
 
-        # TODO: set lower timeout
-        # pyre-fixme[16]: no attribute ProcessGroupGloo
-        self._pg = BaseProcessGroupGloo(store, rank, world_size)
+        # TODO: set global timeout
+        self._pg = self.PG_CLASS(store, rank, world_size)
 
     def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
         return self._pg.allreduce(tensors, opts)
@@ -118,10 +127,35 @@ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
     def size(self) -> int:
         return self._pg.size()
 
+
+class ProcessGroupGloo(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupGloo.
+    """
+
+    PG_CLASS = BaseProcessGroupGloo
+
     def getBackendName(self) -> str:
         return "torchft-gloo"
 
 
+class ProcessGroupNCCL(ProcessGroupWrapper):
+    """
+    This is a reconfigurable version of ProcessGroupNCCL.
+
+    WARNING: this may result in deadlocks due to NCCL error handling. This is
+    provided for completeness but your mileage may vary.
+
+    TODO: verify shutdown correctness with latest NCCL. This currently will call
+    abort when reconfiguring, we need to ensure this is safe.
+    """
+
+    PG_CLASS = BaseProcessGroupNCCL
+
+    def getBackendName(self) -> str:
+        return "torchft-nccl"
+
+
 class DummyWork(dist._Work):
     def __init__(self, result):
         super().__init__()
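
Usage note (not part of the diff): the reconfiguration flow above boils down to calling configure() once per membership change. A minimal sketch, modeled on the test_nccl case added below and assuming a single rank with a CUDA device available:

import torch
from torch.distributed import ReduceOp, TCPStore

from torchft.process_group import ProcessGroupNCCL

# A TCPStore on an ephemeral port, as in the test below.
store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)

pg = ProcessGroupNCCL()
# configure() creates the underlying NCCL process group against this store prefix.
pg.configure(f"localhost:{store.port}/run0", 0, 1)

t = torch.tensor([2], device="cuda")
pg.allreduce([t], ReduceOp.SUM).wait()

# Reconfiguring aborts the old NCCL communicator (when abort() is available)
# and builds a fresh one against a new store prefix.
pg.configure(f"localhost:{store.port}/run1", 0, 1)
pg.allreduce([t], ReduceOp.SUM).wait()
torch.cuda.synchronize()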
torchft/process_group_test.py (32 additions, 0 deletions)
@@ -15,6 +15,7 @@
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
     ProcessGroupGloo,
+    ProcessGroupNCCL,
     ProcessGroupDummy,
     ProcessGroup,
 )
@@ -41,6 +42,37 @@ def test_gloo(self) -> None:
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
         m(torch.rand(2, 3))
 
+    @skipUnless(torch.cuda.is_available(), "needs CUDA")
+    def test_nccl(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+        device = "cuda"
+
+        store_addr = f"localhost:{store.port}/prefix"
+        pg = ProcessGroupNCCL()
+        pg.configure(store_addr, 0, 1)
+
+        self.assertEqual(pg.size(), 1)
+
+        at = torch.tensor([2], device=device)
+        a_work = pg.allreduce([at], ReduceOp.SUM)
+        a_work.wait()
+
+        m = nn.Linear(3, 4).to(device)
+        m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
+        m(torch.rand(2, 3, device=device))
+
+        # reconfigure
+        store_addr = f"localhost:{store.port}/prefix2"
+        pg.configure(store_addr, 0, 1)
+
+        at = torch.tensor([2], device=device)
+        a_work = pg.allreduce([at], ReduceOp.SUM)
+        a_work.wait()
+
+        torch.cuda.synchronize()
+
     def test_baby_gloo(self) -> None:
         store = TCPStore(
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
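
Extension note (not part of the diff): the test exercises the same PG_CLASS hook that ProcessGroupGloo and ProcessGroupNCCL use above. A hedged sketch of what another backend wrapper could look like under this pattern; it reuses the Gloo backend purely as a placeholder so the snippet stays self-contained, and the class and backend name here are hypothetical:

from torch.distributed import ProcessGroupGloo as BaseProcessGroupGloo

from torchft.process_group import ProcessGroupWrapper


class ProcessGroupExample(ProcessGroupWrapper):
    """
    A reconfigurable wrapper around a placeholder backend; swap PG_CLASS
    for any torch.distributed ProcessGroup implementation.
    """

    PG_CLASS = BaseProcessGroupGloo

    def getBackendName(self) -> str:
        return "torchft-example"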
