diff --git a/torchft/process_group.py b/torchft/process_group.py index d90c50e..52aa0d1 100644 --- a/torchft/process_group.py +++ b/torchft/process_group.py @@ -124,7 +124,11 @@ def create_pg( ) -> ProcessGroup: return self - dist.Backend.register_backend(group_name, create_pg) + if torch.cuda.is_available(): + devices = ["cuda", "cpu"] + else: + devices = ["cpu"] + dist.Backend.register_backend(group_name, create_pg, devices=devices) return dist.new_group( ranks=[dist.get_rank()],