Skip to content

Commit

Permalink
Add torch.cuda.is_available() condition
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi authored Aug 8, 2024
1 parent c82b2cf commit 7d99a95
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion nerfstudio/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def train_loop(local_rank: int, world_size: int, config: TrainerConfig, global_r
config: config file specifying training regimen
"""
_set_random_seed(config.machine.seed + global_rank)
torch.cuda.set_device(local_rank)
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
trainer = config.setup(local_rank=local_rank, world_size=world_size)
trainer.setup()
trainer.train()
Expand Down

0 comments on commit 7d99a95

Please sign in to comment.