From 81f7d0b8ec8325860dad76d39a30c1ea33f6821a Mon Sep 17 00:00:00 2001 From: mmcs-work <28564860+mmcs-work@users.noreply.github.com> Date: Tue, 19 Sep 2023 00:21:37 +0200 Subject: [PATCH] Fixes #1720 (https://github.com/azavea/raster-vision/pull/1921#issuecomment-1723698355 - first point) --- .../rastervision/pytorch_learner/learner.py | 20 ++++++++----------- .../pytorch_learner/learner_config.py | 7 +++++++ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py index 4d79434bb5..7007264b9b 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py @@ -45,6 +45,7 @@ warnings.filterwarnings('ignore') +CHECKPOINTS_DIRNAME = 'checkpoints' MODULES_DIRNAME = 'modules' TRANSFORMS_DIRNAME = 'custom_albumentations_transforms' BUNDLE_MODEL_WEIGHTS_FILENAME = 'model.pth' @@ -88,8 +89,7 @@ def __init__(self, model_weights_path: Optional[str] = None, model_def_path: Optional[str] = None, loss_def_path: Optional[str] = None, - training: bool = True, - save_all_checkpoints: bool = False): + training: bool = True): """Constructor. Args: @@ -135,11 +135,6 @@ def __init__(self, model will be put into eval mode. If True, the training apparatus will be set up and the model will be put into training mode. Defaults to True. - save_all_checkpoints (bool, optional): If True, all checkpoints - would be saved. The latest checkpoint would be saved as - `last-model.pth`. The checkpoints prior to last epoch are stored - as `model-ckpt-epoch-{N}.pth` where `N` is the epoch number. Defaults - to False. """ self.cfg = cfg @@ -163,7 +158,6 @@ def __init__(self, self.opt = optimizer self.epoch_scheduler = epoch_scheduler self.step_scheduler = step_scheduler - self.save_all_checkpoints = save_all_checkpoints # --------------------------- # Set URIs @@ -182,7 +176,6 @@ def __init__(self, else: self.output_dir = output_dir self.model_bundle_uri = join(self.output_dir, 'model-bundle.zip') - if is_local(self.output_dir): self.output_dir_local = self.output_dir make_dir(self.output_dir_local) @@ -195,6 +188,9 @@ def __init__(self, log.info(f'Remote output dir: {self.output_dir}') self.modules_dir = join(self.output_dir, MODULES_DIRNAME) + self.checkpoints_dir_local = join(self.output_dir_local, CHECKPOINTS_DIRNAME) + make_dir(self.checkpoints_dir_local) + # --------------------------- self._onnx_mode = False self.setup_model( @@ -529,10 +525,10 @@ def on_epoch_end(self, curr_epoch: int, metrics: MetricDict) -> None: self.tb_writer.add_scalar(key, val, curr_epoch) self.tb_writer.flush() - if self.save_all_checkpoints and curr_epoch > 0: + if self.cfg.save_all_checkpoints and curr_epoch > 0: checkpoint_name = f'model-ckpt-epoch-{curr_epoch - 1}.pth' - checkpoint_path = join(self.output_dir_local, checkpoint_name) - shutil.copy(self.last_model_weights_path, checkpoint_path) + checkpoint_path = join(self.checkpoints_dir_local, checkpoint_name) + shutil.move(self.last_model_weights_path, checkpoint_path) torch.save(self.model.state_dict(), self.last_model_weights_path) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py index 30d4cb5eb5..52bfbff82b 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py @@ -1377,6 +1377,13 @@ class LearnerConfig(Config): False, description='run Tensorboard server during training') output_uri: Optional[str] = Field( None, description='URI of where to save output') + save_all_checkpoints: bool = Field( + False, + description= + ('If True, all checkpoints would be saved. The latest checkpoint ' + 'would be saved as `last-model.pth`. The checkpoints prior to ' + 'last epoch are stored as `model-ckpt-epoch-{N}.pth` where `N` ' + 'is the epoch number.')) @validator('run_tensorboard') def validate_run_tensorboard(cls, v: bool, values: dict) -> bool: