Skip to content

Commit

Permalink
implement distributed training and validation
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Jan 4, 2024
1 parent ddac685 commit b60f8ed
Show file tree
Hide file tree
Showing 5 changed files with 563 additions and 106 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
import logging

import torch.distributed as dist

from rastervision.pytorch_learner.learner import Learner
from rastervision.pytorch_learner.utils import (
compute_conf_mat_metrics, compute_conf_mat, aggregate_metrics)
Expand Down Expand Up @@ -35,6 +37,13 @@ def validate_step(self, batch, batch_ind):
def validate_end(self, outputs):
metrics = aggregate_metrics(outputs, exclude_keys={'conf_mat'})
conf_mat = sum([o['conf_mat'] for o in outputs])

if self.is_ddp_process:
metrics = self.reduce_distributed_metrics(metrics)
dist.reduce(conf_mat, dst=0, op=dist.ReduceOp.SUM)
if not self.is_ddp_master:
return metrics

conf_mat_metrics = compute_conf_mat_metrics(conf_mat,
self.cfg.data.class_names)
metrics.update(conf_mat_metrics)
Expand Down
Loading

0 comments on commit b60f8ed

Please sign in to comment.