wip [skip ci]
AdeelH committed Jan 2, 2024
1 parent c0597b1 commit 355b299
Showing 1 changed file with 44 additions and 13 deletions.
@@ -465,6 +465,7 @@ def _train_distributed(self, epochs: Optional[int] = None):
self.on_train_start()

train_dl = self.build_dataloader('train')
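# Build the validation dataloader up front; with the 'valid' sampler added below, each DDP worker validates its own shard of the dataset.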
val_dl = self.build_dataloader('valid')
for epoch in range(start_epoch, end_epoch):
log.info(f'epoch: {epoch}')

@@ -474,23 +475,19 @@ def _train_distributed(self, epochs: Optional[int] = None):
optimizer=self.opt,
step_scheduler=self.step_scheduler,
dataloader=train_dl)
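# Average the per-worker training metrics across all DDP workers (see reduce_distributed_metrics below).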
train_metrics = self.reduce_distributed_metrics(train_metrics)

if self.epoch_scheduler:
self.epoch_scheduler.step()

for k in train_metrics.keys():
v = train_metrics[k]
if isinstance(v, Tensor):
v = dist.reduce(v, dst=0, op=dist.ReduceOp.SUM)
if self.is_ddp_master:
train_metrics[k] = v / self.ddp_world_size
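# Every worker now runs validation on its own shard; the resulting metrics are then averaged across workers.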
valid_metrics = self.validate_epoch(val_dl)
valid_metrics = self.reduce_distributed_metrics(valid_metrics)

if self.is_ddp_master:
valid_metrics = self.validate_epoch(self.valid_dl)
metrics = dict(epoch=epoch, **train_metrics, **valid_metrics)
log.info(f'metrics:\n{pformat(metrics, sort_dicts=False)}')
self.on_epoch_end(epoch, metrics)

if self.epoch_scheduler:
self.epoch_scheduler.step()

dist.barrier()

def _run_train_distributed(self, rank: int, world_size: int, *args):
@@ -538,7 +535,11 @@ def train_epoch(
dataloader = self.train_dl
start = time.time()
outputs = []
with tqdm(self.train_dl, desc='Training') as bar:
if self.ddp_rank is not None:
desc = f'Training (worker={self.ddp_rank})'
else:
desc = 'Training'
with tqdm(self.train_dl, desc=desc) as bar:
for batch_ind, (x, y) in enumerate(bar):
x = self.to_device(x, self.device)
y = self.to_device(y, self.device)
@@ -606,8 +607,12 @@ def validate_epoch(self, dl: DataLoader) -> MetricDict:
start = time.time()
self.model.eval()
outputs = []
if self.ddp_rank is not None:
desc = f'Validating (worker={self.ddp_rank})'
else:
desc = 'Validating'
with torch.inference_mode():
with tqdm(dl, desc='Validating') as bar:
with tqdm(dl, desc=desc) as bar:
for batch_ind, (x, y) in enumerate(bar):
x = self.to_device(x, self.device)
y = self.to_device(y, self.device)
@@ -1184,14 +1189,22 @@ def build_samplers(

def build_sampler(self, split: str) -> Optional['Sampler']:
"""Return an optional sampler for the split's dataloader."""
split = split.lower()
sampler = None
if split.lower() == 'train':
if split == 'train':
if self.distributed:
sampler = DistributedSampler(
self.train_ds,
shuffle=True,
num_replicas=self.ddp_world_size,
rank=self.ddp_rank)
elif split == 'valid':
if self.distributed:
sampler = DistributedSampler(
self.valid_ds,
shuffle=False,
num_replicas=self.ddp_world_size,
rank=self.ddp_rank)
return sampler
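# Note: when a sampler is passed to a DataLoader, shuffle must be left unset; shuffling (or not) is controlled entirely by the sampler.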

def setup_loss(self, loss_def_path: Optional[str] = None) -> None:
@@ -1462,6 +1475,24 @@ def _bundle_transforms(self, model_bundle_dir: str) -> None:
#########
# Misc.
#########
def reduce_distributed_metrics(self, metrics: dict) -> dict:
    """Average numeric metrics across all DDP workers."""
    log.debug(
        f'{self.ddp_rank}: metrics:\n{pformat(metrics, sort_dicts=False)}')
    # Make sure every worker has finished computing its metrics.
    dist.barrier()
    for k in metrics.keys():
        v = metrics[k]
        if isinstance(v, (float, int)):
            v = torch.tensor(v, device=self.device)
        if isinstance(v, Tensor):
            # all_reduce() sums the tensor in place across workers and
            # returns None, so its return value must not be re-assigned.
            dist.all_reduce(v, op=dist.ReduceOp.SUM)
        if self.is_ddp_master:
            metrics[k] = v / self.ddp_world_size
    return metrics

def post_forward(self, x: Any) -> Any:
"""Post process output of call to model().

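For context, here is a minimal, hypothetical sketch (not part of this commit; dataset and batch size are illustrative) of how a sampler produced by build_sampler is typically consumed, assuming the default process group has already been initialized (e.g. by torchrun):

# Illustrative sketch only -- not from this commit.
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

train_ds = TensorDataset(torch.arange(100).float(), torch.arange(100))
sampler = DistributedSampler(
    train_ds,
    shuffle=True,
    num_replicas=dist.get_world_size(),
    rank=dist.get_rank())
# shuffle must not also be passed to the DataLoader when a sampler is given.
train_dl = DataLoader(train_ds, batch_size=8, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)  # re-seed so each epoch yields a new shuffle
    for x, y in train_dl:
        pass  # forward/backward/step would go here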