diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner.py index c0eb047dff..17427e0f99 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/classification_learner.py @@ -1,6 +1,8 @@ import warnings import logging +import torch.distributed as dist + from rastervision.pytorch_learner.learner import Learner from rastervision.pytorch_learner.utils import ( compute_conf_mat_metrics, compute_conf_mat, aggregate_metrics) @@ -35,6 +37,13 @@ def validate_step(self, batch, batch_ind): def validate_end(self, outputs): metrics = aggregate_metrics(outputs, exclude_keys={'conf_mat'}) conf_mat = sum([o['conf_mat'] for o in outputs]) + + if self.is_ddp_process: + metrics = self.reduce_distributed_metrics(metrics) + dist.reduce(conf_mat, dst=0, op=dist.ReduceOp.SUM) + if not self.is_ddp_master: + return metrics + conf_mat_metrics = compute_conf_mat_metrics(conf_mat, self.cfg.data.class_names) metrics.update(conf_mat_metrics) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py index 108564966f..b7050f34bf 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py @@ -1,6 +1,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Union, Type) from abc import ABC, abstractmethod +import os from os.path import join, isfile, basename, isdir import warnings import time @@ -18,9 +19,13 @@ from torch import Tensor import torch.nn as nn from torch.utils.tensorboard import SummaryWriter -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +import torch.distributed as dist +import torch.multiprocessing as mp from rastervision.pipeline import rv_config_ as rv_config +from rastervision.pipeline.utils import get_env_var from rastervision.pipeline.file_system import ( sync_to_dir, json_to_file, file_to_json, make_dir, zipdir, download_if_needed, download_or_copy, sync_from_dir, get_local_path, unzip, @@ -49,6 +54,8 @@ TRANSFORMS_DIRNAME = 'custom_albumentations_transforms' BUNDLE_MODEL_WEIGHTS_FILENAME = 'model.pth' BUNDLE_MODEL_ONNX_FILENAME = 'model.onnx' +DDP_BACKEND = rv_config.get_namespace_option('rastervision', 'DDP_BACKEND', + 'nccl') log = logging.getLogger(__name__) @@ -63,13 +70,25 @@ class Learner(ABC): The datasets, model, optimizer, and schedulers will be generated from the cfg if not specified in the constructor. - If instantiated with training=False, the training apparatus (loss, + If instantiated with `training=False`, the training apparatus (loss, optimizer, scheduler, logging, etc.) will not be set up and the model will be put into eval mode. - Note that various training and prediction methods have the side effect of - putting Learner.model into training or eval mode. No attempt is made to put the - model back into the mode it was previously in. + .. note:: + + This class supports distributed training via PyTorch DDP. If + instantiated as a DDP process, it will automatically read WORLD_SIZE, + RANK, and LOCAL_RANK environment variables. Alternatively, if + ``RASTERVISION_USE_DDP=YES`` (the default), and multiple GPUs are + detected, it will spawn DDP processes itself (one per GPU) when + training. DDP options that may be set via environment variables or an + INI file (see :ref:`raster vision config`) are: + + - ``RASTERVISION_USE_DDP``: Use DDP? Default: ``YES``. + - ``RASTERVISION_DDP_BACKEND``: Default: ``nccl``. + - ``RASTERVISION_DDP_START_METHOD``: One of ``spawn``, ``fork``, or + ``forkserver``. Default: ``spawn``. + """ def __init__(self, @@ -145,14 +164,7 @@ def __init__(self, tmp_dir = self._tmp_dir.name self.tmp_dir = tmp_dir - if torch.cuda.is_available(): - device = 'cuda' - elif torch.backends.mps.is_available(): - device = 'mps' - else: - device = 'cpu' - - self.device = torch.device(device) + self.training = training self.train_ds = train_ds self.valid_ds = valid_ds @@ -172,6 +184,20 @@ def __init__(self, self.tb_writer = None self.tb_log_dir = None + self.setup_ddp_params() + + if self.avoid_activating_cuda_runtime: + device = 'cuda' + else: + if torch.cuda.is_available(): + device = 'cuda' + elif torch.backends.mps.is_available(): + device = 'mps' + else: + device = 'cpu' + + self.device = torch.device(device) + # --------------------------- # Set URIs # --------------------------- @@ -195,7 +221,7 @@ def __init__(self, else: self.output_dir_local = get_local_path(self.output_dir, tmp_dir) make_dir(self.output_dir_local, force_empty=True) - if training and not cfg.overfit_mode: + if self.training and not cfg.overfit_mode: self.sync_from_cloud() log.info(f'Local output dir: {self.output_dir_local}') log.info(f'Remote output dir: {self.output_dir}') @@ -207,13 +233,19 @@ def __init__(self, # --------------------------- self._onnx_mode = False - self.setup_model( - model_weights_path=model_weights_path, - model_def_path=model_def_path) + self.init_model_weights_path = model_weights_path + self.init_model_def_path = model_def_path + self.init_loss_def_path = loss_def_path + + if not self.distributed: + self.setup_model( + model_weights_path=model_weights_path, + model_def_path=model_def_path) - if training: + if self.training: self.setup_training(loss_def_path=loss_def_path) - self.model.train() + if self.model is not None: + self.model.train() else: if not self.onnx_mode: self.model.eval() @@ -362,7 +394,8 @@ def main(self): resume if interrupted), logs stats, plots predictions, and syncs results to the cloud. """ - log_system_details() + if not self.avoid_activating_cuda_runtime: + log_system_details() log.info(self.cfg) log.info(f'Using device: {self.device}') self.log_data_stats() @@ -370,7 +403,8 @@ def main(self): cfg = self.cfg if not cfg.predict_mode: - self.plot_dataloaders(self.cfg.data.preview_batch_limit) + if not self.avoid_activating_cuda_runtime: + self.plot_dataloaders(self.cfg.data.preview_batch_limit) if cfg.overfit_mode: self.overfit() else: @@ -382,14 +416,30 @@ def main(self): self.stop_tensorboard() if cfg.eval_train: - self.eval_model('train') - self.eval_model('valid') + self.validate('train') + self.validate('valid') self.sync_to_cloud() ########################### # Training and validation ########################### def train(self, epochs: Optional[int] = None): + """Run training loop.""" + if self.is_ddp_process: + self._run_train_distributed(self.ddp_rank, self.ddp_world_size, + epochs) + elif self.distributed: + log.info('Spawning %d DDP processes', self.ddp_world_size) + mp.start_processes( + self._run_train_distributed, + args=(self.ddp_world_size, epochs), + nprocs=self.ddp_world_size, + join=True, + start_method=self.ddp_start_method) + else: + self._train(epochs) + + def _train(self, epochs: Optional[int] = None): """Training loop that will attempt to resume training if appropriate.""" start_epoch = self.get_start_epoch() @@ -404,25 +454,112 @@ def train(self, epochs: Optional[int] = None): self.on_train_start() for epoch in range(start_epoch, end_epoch): log.info(f'epoch: {epoch}') + train_metrics = self.train_epoch( optimizer=self.opt, step_scheduler=self.step_scheduler) + if self.epoch_scheduler: self.epoch_scheduler.step() + valid_metrics = self.validate_epoch(self.valid_dl) + metrics = dict(epoch=epoch, **train_metrics, **valid_metrics) log.info(f'metrics:\n{pformat(metrics, sort_dicts=False)}') self.on_epoch_end(epoch, metrics) + def _train_distributed(self, epochs: Optional[int] = None): + """Training loop that will attempt to resume training if appropriate.""" + start_epoch = self.get_start_epoch() + + if epochs is None: + end_epoch = self.cfg.solver.num_epochs + else: + end_epoch = start_epoch + epochs + + if (start_epoch > 0 and start_epoch < end_epoch): + log.info(f'Resuming training from epoch {start_epoch}') + + if self.is_ddp_master: + self.on_train_start() + + train_dl = self.build_dataloader('train', distributed=True) + val_dl = self.build_dataloader('valid', distributed=True) + for epoch in range(start_epoch, end_epoch): + log.info(f'epoch: {epoch}') + + train_dl.sampler.set_epoch(epoch) + + train_metrics = self.train_epoch( + optimizer=self.opt, + step_scheduler=self.step_scheduler, + dataloader=train_dl) + + valid_metrics = self.validate_epoch(val_dl) + + if self.is_ddp_master: + metrics = dict(epoch=epoch, **train_metrics, **valid_metrics) + log.info(f'metrics:\n{pformat(metrics, sort_dicts=False)}') + self.on_epoch_end(epoch, metrics) + + if self.epoch_scheduler: + self.epoch_scheduler.step() + + dist.barrier() + + def _run_train_distributed(self, rank: int, world_size: int, *args): + """Method executed by each DDP worker.""" + + os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR', 'localhost') + os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '12355') + + dist.init_process_group(DDP_BACKEND, rank=rank, world_size=world_size) + + if self.ddp_rank is None: + self.ddp_rank = rank + if self.ddp_local_rank is None: + # Implies process was spawned by self.train(), and therefore, + # this is necessarily a single-node multi-GPU scenario. + # So global rank == local rank. + self.ddp_local_rank = rank + if self.ddp_world_size is None: + self.ddp_world_size = world_size + + log.info('DDP rank: %d, DDP local rank: %d', self.ddp_rank, + self.ddp_local_rank) + + self.is_ddp_process = True + self.is_ddp_master = self.ddp_rank == 0 + if self.device.index is None: + self.device = torch.device(self.device.type, self.ddp_local_rank) + torch.cuda.set_device(self.device) + + self.setup_model( + model_weights_path=self.init_model_weights_path, + model_def_path=self.init_model_def_path) + self.setup_training() + dist.barrier() + + self._train_distributed(*args) + + dist.destroy_process_group() + def train_epoch( self, optimizer: 'Optimizer', + dataloader: Optional[DataLoader] = None, step_scheduler: Optional['_LRScheduler'] = None) -> MetricDict: """Train for a single epoch.""" - start = time.time() self.model.train() + if dataloader is None: + dataloader = self.train_dl + start = time.time() outputs = [] - with tqdm(self.train_dl, desc='Training') as bar: + if self.ddp_rank is not None: + desc = f'Training (GPU={self.ddp_rank})' + else: + desc = 'Training' + with tqdm(self.train_dl, desc=desc) as bar: for batch_ind, (x, y) in enumerate(bar): x = self.to_device(x, self.device) y = self.to_device(y, self.device) @@ -467,10 +604,29 @@ def train_end(self, outputs: List[Dict[str, Union[float, Tensor]]] Args: outputs: a list of outputs of train_step """ - return aggregate_metrics(outputs) + metrics = aggregate_metrics(outputs) + if self.is_ddp_process: + metrics = self.reduce_distributed_metrics(metrics) + return metrics + + def validate(self, split: Literal['train', 'valid', 'test'] = 'valid'): + """Evaluate model on a particular data split.""" + if self.is_ddp_process: + self._run_validate_distributed(self.ddp_rank, self.ddp_world_size, + split) + elif self.distributed: + log.info('Spawning DDP processes') + mp.start_processes( + self._run_validate_distributed, + args=(self.ddp_world_size, split), + nprocs=self.ddp_world_size, + join=True, + start_method=self.ddp_start_method) + else: + self._validate(split) - def eval_model(self, split: str): - """Evaluate model using a particular dataset split. + def _validate(self, split: Literal['train', 'valid', 'test'] = 'valid'): + """Evaluate model on a particular data split. Gets validation metrics and saves them along with prediction plots. @@ -479,19 +635,62 @@ def eval_model(self, split: str): """ log.info(f'Evaluating on {split} set...') dl = self.get_dataloader(split) + if dl is None: + self.setup_data() + dl = self.get_dataloader(split) metrics = self.validate_epoch(dl) + if self.is_ddp_process and not self.is_ddp_master: + return log.info(f'metrics: {metrics}') json_to_file(metrics, join(self.output_dir_local, f'{split}_metrics.json')) self.plot_predictions(split, self.cfg.data.preview_batch_limit) + def _run_validate_distributed(self, rank: int, world_size: int, *args): + """Method executed by each DDP worker.""" + + os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR', 'localhost') + os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '12355') + + dist.init_process_group(DDP_BACKEND, rank=rank, world_size=world_size) + + if self.ddp_rank is None: + self.ddp_rank = rank + if self.ddp_local_rank is None: + self.ddp_local_rank = rank + if self.ddp_world_size is None: + self.ddp_world_size = world_size + + log.info('DDP rank: %d, DDP local rank: %d', self.ddp_rank, + self.ddp_local_rank) + + self.is_ddp_process = True + self.is_ddp_master = self.ddp_rank == 0 + if self.device.index is None: + self.device = torch.device(self.device.type, self.ddp_local_rank) + torch.cuda.set_device(self.device) + + self.setup_model( + model_weights_path=self.init_model_weights_path, + model_def_path=self.init_model_def_path) + self.setup_loss(self.init_loss_def_path) + dist.barrier() + + self._validate(*args) + + dist.destroy_process_group() + def validate_epoch(self, dl: DataLoader) -> MetricDict: """Validate for a single epoch.""" start = time.time() self.model.eval() outputs = [] + if self.ddp_rank is not None: + desc = f'Validating (GPU={self.ddp_rank})' + else: + desc = 'Validating' with torch.inference_mode(): - with tqdm(dl, desc='Validating') as bar: + with tqdm(dl, desc=desc) as bar: for batch_ind, (x, y) in enumerate(bar): x = self.to_device(x, self.device) y = self.to_device(y, self.device) @@ -524,7 +723,10 @@ def validate_end(self, outputs: List[Dict[str, Union[float, Tensor]]] Args: outputs: a list of outputs of validate_step """ - return aggregate_metrics(outputs) + metrics = aggregate_metrics(outputs) + if self.is_ddp_process: + metrics = self.reduce_distributed_metrics(metrics) + return metrics def on_epoch_end(self, curr_epoch: int, metrics: MetricDict) -> None: """Hook that is called at end of epoch. @@ -544,7 +746,7 @@ def on_epoch_end(self, curr_epoch: int, metrics: MetricDict) -> None: checkpoint_path = join(self.checkpoints_dir_local, checkpoint_name) shutil.move(self.last_model_weights_path, checkpoint_path) - torch.save(self.model.state_dict(), self.last_model_weights_path) + self.save_weights(self.last_model_weights_path) if (curr_epoch + 1) % self.cfg.solver.sync_interval == 0: self.sync_to_cloud() @@ -569,7 +771,7 @@ def overfit(self): log.info('\nstep: %d', step) log.info('train_loss: %f', loss) - torch.save(self.model.state_dict(), self.last_model_weights_path) + self.save_weights(self.last_model_weights_path) def on_overfit_start(self): """Hook that is called at start of overfit routine.""" @@ -808,33 +1010,140 @@ def prob_to_pred(self, x: Tensor) -> Tensor: ######### # Setup ######### + def setup_ddp_params(self): + """Set up and validate params related to PyTorch DDP.""" + + ddp_allowed = rv_config.get_namespace_option( + 'rastervision', 'USE_DDP', True, as_bool=True) + self.ddp_start_method = rv_config.get_namespace_option( + 'rastervision', 'DDP_START_METHOD', 'spawn').lower() + + self.is_ddp_process = False + self.is_ddp_master = False + self.avoid_activating_cuda_runtime = False + + self.ddp_world_size = get_env_var('WORLD_SIZE', None, int) + self.ddp_rank = get_env_var('RANK', None, int) + self.ddp_local_rank = get_env_var('LOCAL_RANK', None, int) + + if dist.is_initialized(): + if not ddp_allowed: + log.info('Ignoring RASTERVISION_USE_DDP since DDP is already ' + 'initialized.') + ddp_vars_set = all( + [self.ddp_world_size, self.ddp_rank, self.ddp_local_rank]) + if not ddp_vars_set: + raise ValueError( + 'Is DDP process but WORLD_SIZE, RANK, and LOCAL_RANK ' + 'env variables not set.') + self.distributed = True + self.is_ddp_process = True + self.is_ddp_master = self.ddp_rank == 0 + elif not ddp_allowed: + self.distributed = False + elif self.ddp_start_method != 'spawn': + # If ddp_start_method is "fork" or "forkserver", the CUDA runtime + # must not be initialized before the fork; otherwise, a + # "RuntimeError: Cannot re-initialize CUDA in forked subprocess." + # error will be raised. We can avoid initializing it by not + # calling and torch.cuda functions or creating tensors on the GPU. + if self.ddp_world_size is None: + raise ValueError( + 'WORLD_SIZE env variable must be specified if ' + 'RASTERVISION_DDP_START_METHOD is not "spawn".') + self.distributed = True + self.avoid_activating_cuda_runtime = True + elif torch.cuda.is_available(): + dist_available = dist.is_available() + gpu_count = torch.cuda.device_count() + multi_gpus = gpu_count > 1 + self.distributed = ddp_allowed and dist_available and multi_gpus + if self.distributed: + log.info( + 'Multiple GPUs detected (%d), will use DDP for training.', + gpu_count) + world_size_is_set = self.ddp_world_size is not None + if not world_size_is_set: + self.ddp_world_size = gpu_count + if world_size_is_set and self.ddp_world_size < gpu_count: + log.info('Using only WORLD_SIZE=%d of total %d GPUs.', + self.ddp_world_size, gpu_count) + else: + self.distributed = False + + if not self.distributed: + return + + self.ddp_world_size: int + + if self.model is not None: + raise ValueError( + 'In distributed mode, the model must be specified via ' + 'ModelConfig in LearnerConfig rather than be passed ' + 'as an instantiated object.') + dses_passed = any([self.train_ds, self.valid_ds, self.test_ds]) + if dses_passed and self.ddp_start_method != 'fork': + raise ValueError( + 'In distributed mode, if ' + 'RASTERVISION_DDP_START_METHOD != "fork", datasets must be ' + 'specified via DataConfig in LearnerConfig rather than be ' + 'passed as instantiated objects.') + if self.ddp_local_rank is not None: + self.device = torch.device('cuda', self.ddp_local_rank) + + log.info('Using DDP') + log.info(f'World size: {self.ddp_world_size}') + log.info(f'DDP start method: {self.ddp_start_method}') + if self.is_ddp_process: + log.info(f'DDP rank: {self.ddp_rank}') + log.info(f'DDP local rank: {self.ddp_local_rank}') + def setup_training(self, loss_def_path: Optional[str] = None) -> None: cfg = self.cfg self.config_path = join(self.output_dir, 'learner-config.json') str_to_file(cfg.json(), self.config_path) - self.log_path = join(self.output_dir_local, 'log.csv') - - # data - self.setup_data() - - # model self.last_model_weights_path = join(self.output_dir_local, 'last-model.pth') - self.load_checkpoint() - # optimization - start_epoch = self.get_start_epoch() - self.setup_loss(loss_def_path=loss_def_path) - if self.opt is None: - self.opt = self.build_optimizer() - if self.step_scheduler is None: - self.step_scheduler = self.build_step_scheduler(start_epoch) - if self.epoch_scheduler is None: - self.epoch_scheduler = self.build_epoch_scheduler(start_epoch) - - self.setup_tensorboard() + if self.is_ddp_process: + # model + if self.model is not None: + self.load_checkpoint() + # data + self.setup_data() + # optimization + start_epoch = self.get_start_epoch() + self.setup_loss(loss_def_path=loss_def_path) + if self.opt is None: + self.opt = self.build_optimizer() + if self.step_scheduler is None: + self.step_scheduler = self.build_step_scheduler(start_epoch) + if self.epoch_scheduler is None: + self.epoch_scheduler = self.build_epoch_scheduler(start_epoch) + + if self.is_ddp_master: + self.setup_tensorboard() + elif self.distributed: + if self.ddp_start_method == 'fork': + self.setup_data() + else: + # data + self.setup_data() + # model + self.load_checkpoint() + # optimization + start_epoch = self.get_start_epoch() + self.setup_loss(loss_def_path=loss_def_path) + if self.opt is None: + self.opt = self.build_optimizer() + if self.step_scheduler is None: + self.step_scheduler = self.build_step_scheduler(start_epoch) + if self.epoch_scheduler is None: + self.epoch_scheduler = self.build_epoch_scheduler(start_epoch) + + self.setup_tensorboard() def get_start_epoch(self) -> int: """Get start epoch. @@ -873,6 +1182,9 @@ def setup_model(self, if self.model is None: self.model = self.build_model(model_def_path=model_def_path) self.model.to(device=self.device) + if self.is_ddp_process: + self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model) + self.model = DDP(self.model, device_ids=[self.ddp_rank]) self.load_init_weights(model_weights_path=model_weights_path) def build_model(self, model_def_path: Optional[str] = None) -> nn.Module: @@ -888,23 +1200,39 @@ def build_model(self, model_def_path: Optional[str] = None) -> nn.Module: num_classes=cfg.data.num_classes, in_channels=in_channels, save_dir=self.modules_dir, - hubconf_dir=model_def_path) + hubconf_dir=model_def_path, + is_ddp_process=self.is_ddp_process, + is_ddp_master=self.is_ddp_master) return model - def setup_data(self): + def setup_data(self, distributed: Optional[bool] = None): """Set datasets and dataLoaders for train, validation, and test sets. """ + if distributed is None: + distributed = self.distributed + if self.train_ds is None or self.valid_ds is None: - train_ds, valid_ds, test_ds = self.build_datasets() + if distributed: + if self.is_ddp_master: + train_ds, valid_ds, test_ds = self.build_datasets() + dist.barrier() + if not self.is_ddp_master: + train_ds, valid_ds, test_ds = self.build_datasets() + dist.barrier() + else: + train_ds, valid_ds, test_ds = self.build_datasets() if self.train_ds is None: self.train_ds = train_ds if self.valid_ds is None: self.valid_ds = valid_ds if self.test_ds is None: self.test_ds = test_ds - self.train_dl, self.valid_dl, self.test_dl = self.build_dataloaders() + + self.train_dl, self.valid_dl, self.test_dl = self.build_dataloaders( + distributed=distributed) def build_datasets(self) -> Tuple['Dataset', 'Dataset', 'Dataset']: + """Build Datasets for train, validation, and test splits.""" log.info(f'Building datasets ...') cfg = self.cfg train_ds, val_ds, test_ds = self.cfg.data.build( @@ -913,46 +1241,76 @@ def build_datasets(self) -> Tuple['Dataset', 'Dataset', 'Dataset']: test_mode=cfg.test_mode) return train_ds, val_ds, test_ds - def build_dataloaders(self) -> Tuple[DataLoader, DataLoader, DataLoader]: - """Set the DataLoaders for train, validation, and test sets.""" + def build_dataset(self, split: Literal['train', 'valid', 'test'] + ) -> Tuple['Dataset', 'Dataset', 'Dataset']: + """Build Dataset for split.""" + log.info('Building %s dataset ...', split) + cfg = self.cfg + ds = cfg.data.build_dataset( + split=split, + tmp_dir=self.tmp_dir, + overfit_mode=cfg.overfit_mode, + test_mode=cfg.test_mode) + return ds + + def build_dataloaders(self, distributed: Optional[bool] = None + ) -> Tuple[DataLoader, DataLoader, DataLoader]: + """Build DataLoaders for train, validation, and test splits.""" + if distributed is None: + distributed = self.distributed + + train_dl = self.build_dataloader('train', distributed=distributed) + val_dl = self.build_dataloader('valid', distributed=distributed) + + test_dl = None + if self.test_ds is not None and len(self.test_ds) > 0: + test_dl = self.build_dataloader('test', distributed=distributed) + + return train_dl, val_dl, test_dl + def build_dataloader(self, + split: Literal['train', 'valid', 'test'], + distributed: bool = False) -> DataLoader: + """Build DataLoader for split.""" + ds = self.get_dataset(split) + if ds is None: + ds = self.build_dataset(split) batch_sz = self.cfg.solver.batch_sz num_workers = self.cfg.data.num_workers collate_fn = self.get_collate_fn() + sampler = self.build_sampler(split, distributed=distributed) + + if distributed: + world_sz = self.ddp_world_size + if world_sz is None: + raise ValueError('World size not set. ' + 'Cannot determine per-process batch size.') + if world_sz > batch_sz: + raise ValueError(f'World size ({world_sz}) is greater ' + f'than total batch size ({batch_sz}).') + batch_sz //= world_sz + log.debug('Per GPU batch size: %d', batch_sz) - train_sampler = self.get_train_sampler(self.train_ds) - train_shuffle = train_sampler is None - # batchnorm layers expect batch size > 1 during training - train_drop_last = (len(self.train_ds) % batch_sz) == 1 - train_dl = DataLoader( - self.train_ds, - batch_size=batch_sz, - shuffle=train_shuffle, - drop_last=train_drop_last, - num_workers=num_workers, - pin_memory=True, - collate_fn=collate_fn, - sampler=train_sampler) - - val_dl = DataLoader( - self.valid_ds, + args = dict( batch_size=batch_sz, - shuffle=True, num_workers=num_workers, collate_fn=collate_fn, - pin_memory=True) + pin_memory=True, + multiprocessing_context='fork' if distributed else None, + ) - test_dl = None - if self.test_ds is not None and len(self.test_ds) > 0: - test_dl = DataLoader( - self.test_ds, - batch_size=batch_sz, - shuffle=True, - num_workers=num_workers, - collate_fn=collate_fn, - pin_memory=True) + if sampler is not None: + args['sampler'] = sampler + else: + if split == 'train': + args['shuffle'] = True + # batchnorm layers expect batch size > 1 during training + args['drop_last'] = (len(ds) % batch_sz) == 1 + else: + args['shuffle'] = False - return train_dl, val_dl, test_dl + dl = DataLoader(ds, **args) + return dl def get_collate_fn(self) -> Optional[callable]: """Returns a custom collate_fn to use in DataLoader. @@ -963,9 +1321,27 @@ def get_collate_fn(self) -> Optional[callable]: """ return None - def get_train_sampler(self, train_ds: 'Dataset') -> Optional['Sampler']: - """Return an optional sampler for the training dataloader.""" - return None + def build_sampler(self, + split: Literal['train', 'valid', 'test'], + distributed: bool = False) -> Optional['Sampler']: + """Return an optional sampler for the split's dataloader.""" + split = split.lower() + sampler = None + if split == 'train': + if distributed: + sampler = DistributedSampler( + self.train_ds, + shuffle=True, + num_replicas=self.ddp_world_size, + rank=self.ddp_rank) + elif split == 'valid': + if distributed: + sampler = DistributedSampler( + self.valid_ds, + shuffle=False, + num_replicas=self.ddp_world_size, + rank=self.ddp_rank) + return sampler def setup_loss(self, loss_def_path: Optional[str] = None) -> None: """Setup self.loss. @@ -1013,7 +1389,7 @@ def get_visualizer_class(self) -> Type[Visualizer]: """Returns a Visualizer class object for plotting data samples.""" def plot_predictions(self, - split: str, + split: Literal['train', 'valid', 'test'], batch_limit: Optional[int] = None, show: bool = False): """Plot predictions for a split. @@ -1111,19 +1487,28 @@ def save_model_bundle(self, export_onnx: bool = True): def _bundle_model(self, model_bundle_dir: str, export_onnx: bool = True) -> None: - """Save model weights and copy them to bundle dir..""" - # pytorch + """Save model weights and copy them to bundle dir.""" + model_not_set = self.model is None + if model_not_set: + self.setup_model( + model_weights_path=self.init_model_weights_path, + model_def_path=self.init_model_def_path) + self.load_checkpoint() + path = join(model_bundle_dir, BUNDLE_MODEL_WEIGHTS_FILENAME) if file_exists(self.last_model_weights_path): shutil.copyfile(self.last_model_weights_path, path) else: - torch.save(self.model.state_dict(), path) + self.save_weights(path) # ONNX if export_onnx: path = join(model_bundle_dir, BUNDLE_MODEL_ONNX_FILENAME) self.export_to_onnx(path) + if model_not_set: + self.model = None + def export_to_onnx(self, path: str, model: Optional['nn.Module'] = None, @@ -1153,19 +1538,19 @@ def export_to_onnx(self, if model is None: model = self.model + if isinstance(model, DDP): + model = model.module + training_state = model.training model.eval() if sample_input is None: - for split in ['train', 'valid', 'test']: - dl = self.get_dataloader(split) - if dl is not None: - break - else: - raise ValueError('sample_input not provided and Learner does ' - 'not have a DataLoader to get sample input ' - 'from.') + dl = self.valid_dl + if dl is None: + dl = self.build_dataloader('valid') sample_input, _ = next(iter(dl)) + + torch.cuda.empty_cache() sample_input = self.to_device(sample_input, self.device) args = dict( @@ -1232,6 +1617,17 @@ def _bundle_transforms(self, model_bundle_dir: str) -> None: ######### # Misc. ######### + def reduce_distributed_metrics(self, metrics: dict): + for k in metrics.keys(): + v = metrics[k] + if isinstance(v, (float, int)): + v = torch.tensor(v, device=self.device) + if isinstance(v, Tensor): + dist.reduce(v, dst=0, op=dist.ReduceOp.SUM) + if self.is_ddp_master: + metrics[k] = (v / self.ddp_world_size).item() + return metrics + def post_forward(self, x: Any) -> Any: """Post process output of call to model(). @@ -1267,7 +1663,23 @@ def to_device(self, x: Any, device: str) -> Any: else: return x.to(device) - def get_dataloader(self, split: str) -> DataLoader: + def get_dataset(self, split: Literal['train', 'valid', 'test'] + ) -> Optional[DataLoader]: + """Get the Dataset for a split. + + Args: + split: a split name which can be train, valid, or test + """ + if split == 'train': + return self.train_ds + if split == 'valid': + return self.valid_ds + if split == 'test': + return self.test_ds + raise ValueError(f'{split} is not a valid split') + + def get_dataloader(self, + split: Literal['train', 'valid', 'test']) -> DataLoader: """Get the DataLoader for a split. Args: @@ -1318,6 +1730,13 @@ def load_init_weights(self, log.info(f'Loading model weights from: {uri}') self.load_weights(uri=uri, **args) + def save_weights(self, path: str): + """Save model weights to a local file.""" + model = self.model + if isinstance(model, DDP): + model = model.module + torch.save(model.state_dict(), path) + def load_weights(self, uri: str, **kwargs) -> None: """Load model weights from a file. @@ -1326,7 +1745,10 @@ def load_weights(self, uri: str, **kwargs) -> None: **kwargs: Extra args for :meth:`nn.Module.load_state_dict`. """ weights_path = download_if_needed(uri) - self.model.load_state_dict( + model = self.model + if isinstance(model, DDP): + model = model.module + model.load_state_dict( torch.load(weights_path, map_location=self.device), **kwargs) def load_checkpoint(self): @@ -1379,10 +1801,10 @@ def run_tensorboard(self): def stop_tensorboard(self): """Stop TB logging and server if it's running.""" - if self.cfg.log_tensorboard: + if self.tb_writer is not None: self.tb_writer.close() - if self.cfg.run_tensorboard: - self.tb_process.terminate() + if self.tb_process is not None: + self.tb_process.terminate() @property def onnx_mode(self) -> bool: diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner.py index 1feea898fe..d185738678 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/object_detection_learner.py @@ -6,6 +6,7 @@ import numpy as np import torch +import torch.distributed as dist from rastervision.pytorch_learner.learner import Learner from rastervision.pytorch_learner.object_detection_utils import ( @@ -90,6 +91,23 @@ def validate_end(self, outputs): outs.extend(o['outs']) ys.extend(o['ys']) num_class_ids = len(self.cfg.data.class_names) + + if self.is_ddp_process: + is_master = self.is_ddp_master + all_outs = [None] * self.ddp_world_size + all_ys = [None] * self.ddp_world_size + dist.gather_object( + outs, + object_gather_list=(all_outs if is_master else None), + dst=0) + dist.gather_object( + ys, object_gather_list=(all_ys if is_master else None), dst=0) + if not is_master: + return {} + outs = sum(all_outs, []) + ys = sum(all_ys, []) + + log.info(f'{self.ddp_rank} at coco eval') coco_eval = compute_coco_eval(outs, ys, num_class_ids) metrics = {'mAP': 0.0, 'mAP50': 0.0} diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner.py index 1ba0801726..4a837b896c 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/regression_learner.py @@ -76,8 +76,8 @@ def validate_step(self, batch, batch_nb): def prob_to_pred(self, x): return x - def eval_model(self, split): - super().eval_model(split) + def _validate(self, split): + super()._validate(split) y, out = self.predict_dataloader( self.get_dataloader(split), return_format='yz', raw_out=False) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner.py index 98b69bf21b..03cb0be1eb 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/semantic_segmentation_learner.py @@ -5,6 +5,7 @@ import torch from torch.nn import functional as F +import torch.distributed as dist from rastervision.pytorch_learner.learner import Learner from rastervision.pytorch_learner.utils import ( @@ -44,6 +45,13 @@ def validate_step(self, batch, batch_ind): def validate_end(self, outputs): metrics = aggregate_metrics(outputs, exclude_keys={'conf_mat'}) conf_mat = sum([o['conf_mat'] for o in outputs]) + + if self.is_ddp_process: + metrics = self.reduce_distributed_metrics(metrics) + dist.reduce(conf_mat, dst=0, op=dist.ReduceOp.SUM) + if not self.is_ddp_master: + return metrics + conf_mat_metrics = compute_conf_mat_metrics(conf_mat, self.cfg.data.class_names) metrics.update(conf_mat_metrics)