From 355b29935485b00cebb4725a7bd3625453b67aee Mon Sep 17 00:00:00 2001
From: Adeel Hassan
Date: Tue, 2 Jan 2024 10:26:51 -0500
Subject: [PATCH] wip [skip ci]

---
 .../rastervision/pytorch_learner/learner.py  | 57 ++++++++++++++-----
 1 file changed, 44 insertions(+), 13 deletions(-)

diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py
index c033d16a87..c7d4d0c9a8 100644
--- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py
+++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py
@@ -465,6 +465,7 @@ def _train_distributed(self, epochs: Optional[int] = None):
         self.on_train_start()
 
         train_dl = self.build_dataloader('train')
+        val_dl = self.build_dataloader('valid')
 
         for epoch in range(start_epoch, end_epoch):
             log.info(f'epoch: {epoch}')
@@ -474,23 +475,19 @@
                 optimizer=self.opt,
                 step_scheduler=self.step_scheduler,
                 dataloader=train_dl)
+            train_metrics = self.reduce_distributed_metrics(train_metrics)
 
-            if self.epoch_scheduler:
-                self.epoch_scheduler.step()
-
-            for k in train_metrics.keys():
-                v = train_metrics[k]
-                if isinstance(v, Tensor):
-                    v = dist.reduce(v, dst=0, op=dist.ReduceOp.SUM)
-                    if self.is_ddp_master:
-                        train_metrics[k] = v / self.ddp_world_size
+            valid_metrics = self.validate_epoch(val_dl)
+            valid_metrics = self.reduce_distributed_metrics(valid_metrics)
 
             if self.is_ddp_master:
-                valid_metrics = self.validate_epoch(self.valid_dl)
                 metrics = dict(epoch=epoch, **train_metrics, **valid_metrics)
                 log.info(f'metrics:\n{pformat(metrics, sort_dicts=False)}')
                 self.on_epoch_end(epoch, metrics)
 
+            if self.epoch_scheduler:
+                self.epoch_scheduler.step()
+
             dist.barrier()
 
     def _run_train_distributed(self, rank: int, world_size: int, *args):
@@ -538,7 +535,11 @@ def train_epoch(
             dataloader = self.train_dl
         start = time.time()
         outputs = []
-        with tqdm(self.train_dl, desc='Training') as bar:
+        if self.ddp_rank is not None:
+            desc = f'Training (worker={self.ddp_rank})'
+        else:
+            desc = 'Training'
+        with tqdm(dataloader, desc=desc) as bar:
             for batch_ind, (x, y) in enumerate(bar):
                 x = self.to_device(x, self.device)
                 y = self.to_device(y, self.device)
@@ -606,8 +607,12 @@ def validate_epoch(self, dl: DataLoader) -> MetricDict:
         start = time.time()
         self.model.eval()
         outputs = []
+        if self.ddp_rank is not None:
+            desc = f'Validating (worker={self.ddp_rank})'
+        else:
+            desc = 'Validating'
         with torch.inference_mode():
-            with tqdm(dl, desc='Validating') as bar:
+            with tqdm(dl, desc=desc) as bar:
                 for batch_ind, (x, y) in enumerate(bar):
                     x = self.to_device(x, self.device)
                     y = self.to_device(y, self.device)
@@ -1184,14 +1189,22 @@ def build_samplers(
 
     def build_sampler(self, split: str) -> Optional['Sampler']:
         """Return an optional sampler for the split's dataloader."""
+        split = split.lower()
         sampler = None
-        if split.lower() == 'train':
+        if split == 'train':
            if self.distributed:
                 sampler = DistributedSampler(
                     self.train_ds,
                     shuffle=True,
                     num_replicas=self.ddp_world_size,
                     rank=self.ddp_rank)
+        elif split == 'valid':
+            if self.distributed:
+                sampler = DistributedSampler(
+                    self.valid_ds,
+                    shuffle=False,
+                    num_replicas=self.ddp_world_size,
+                    rank=self.ddp_rank)
         return sampler
 
     def setup_loss(self, loss_def_path: Optional[str] = None) -> None:
@@ -1462,6 +1475,24 @@ def _bundle_transforms(self, model_bundle_dir: str) -> None:
 
     #########
    # Misc.
     #########
+    def reduce_distributed_metrics(self, metrics: dict):
+        log.info(
+            f'{self.ddp_rank}: metrics:\n{pformat(metrics, sort_dicts=False)}')
+        dist.barrier()
+        for k in metrics.keys():
+            v = metrics[k]
+            if isinstance(v, (float, int)):
+                print('tensoring', k, v, flush=True)
+                v = torch.tensor(v, device=self.device)
+            if isinstance(v, Tensor):
+                print(k, v, flush=True)
+                # all_reduce sums in place across ranks (it returns None)
+                dist.all_reduce(v, op=dist.ReduceOp.SUM)
+                print('recvd', self.ddp_rank, k, v, flush=True)
+            if self.is_ddp_master:
+                metrics[k] = v / self.ddp_world_size
+        return metrics
+
     def post_forward(self, x: Any) -> Any:
         """Post process output of call to model().
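
Note (not part of the patch): below is a minimal, self-contained sketch of
the all_reduce averaging pattern that the new reduce_distributed_metrics
helper relies on. It runs on CPU with the gloo backend; reduce_metrics and
run_worker are illustrative names, not functions from the codebase.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def reduce_metrics(metrics: dict, world_size: int) -> dict:
    """Average scalar metrics across all ranks, as the patch's helper does."""
    for k, v in metrics.items():
        if isinstance(v, (float, int)):
            v = torch.tensor(v)
        if isinstance(v, torch.Tensor):
            # all_reduce sums in place; every rank ends up with the global sum.
            dist.all_reduce(v, op=dist.ReduceOp.SUM)
            metrics[k] = (v / world_size).item()
    return metrics


def run_worker(rank: int, world_size: int):
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    # Each rank reports a different loss; the reduced value is their mean.
    metrics = reduce_metrics({'train_loss': float(rank)}, world_size)
    if rank == 0:
        print(metrics)  # {'train_loss': 0.5} for world_size=2
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(run_worker, args=(2, ), nprocs=2)

Because all_reduce leaves every rank holding the same global sum, dividing by
the world size yields the mean on any rank, which is why the patch can do the
division on the master alone and still log correct averaged metrics.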