Fixes azavea#1720 (azavea#1921 (comment) - first point)
mmcs-work committed Sep 18, 2023
1 parent c6fa61d · commit 81f7d0b
Showing 2 changed files with 15 additions and 12 deletions.
File 1:

@@ -45,6 +45,7 @@

 warnings.filterwarnings('ignore')

+CHECKPOINTS_DIRNAME = 'checkpoints'
 MODULES_DIRNAME = 'modules'
 TRANSFORMS_DIRNAME = 'custom_albumentations_transforms'
 BUNDLE_MODEL_WEIGHTS_FILENAME = 'model.pth'
@@ -88,8 +89,7 @@ def __init__(self,
                  model_weights_path: Optional[str] = None,
                  model_def_path: Optional[str] = None,
                  loss_def_path: Optional[str] = None,
-                 training: bool = True,
-                 save_all_checkpoints: bool = False):
+                 training: bool = True):
         """Constructor.

         Args:
@@ -135,11 +135,6 @@ def __init__(self,
                 model will be put into eval mode. If True, the training
                 apparatus will be set up and the model will be put into
                 training mode. Defaults to True.
-            save_all_checkpoints (bool, optional): If True, all checkpoints
-                would be saved. The latest checkpoint would be saved as
-                `last-model.pth`. The checkpoints prior to last epoch are stored
-                as `model-ckpt-epoch-{N}.pth` where `N` is the epoch number. Defaults
-                to False.
         """
         self.cfg = cfg

@@ -163,7 +158,6 @@ def __init__(self,
         self.opt = optimizer
         self.epoch_scheduler = epoch_scheduler
         self.step_scheduler = step_scheduler
-        self.save_all_checkpoints = save_all_checkpoints

         # ---------------------------
         # Set URIs
@@ -182,7 +176,6 @@ def __init__(self,
         else:
             self.output_dir = output_dir
             self.model_bundle_uri = join(self.output_dir, 'model-bundle.zip')
-
         if is_local(self.output_dir):
             self.output_dir_local = self.output_dir
             make_dir(self.output_dir_local)
@@ -195,6 +188,9 @@ def __init__(self,
             log.info(f'Remote output dir: {self.output_dir}')

         self.modules_dir = join(self.output_dir, MODULES_DIRNAME)
+        self.checkpoints_dir_local = join(self.output_dir_local, CHECKPOINTS_DIRNAME)
+        make_dir(self.checkpoints_dir_local)
+
         # ---------------------------
         self._onnx_mode = False
         self.setup_model(
@@ -529,10 +525,10 @@ def on_epoch_end(self, curr_epoch: int, metrics: MetricDict) -> None:
                     self.tb_writer.add_scalar(key, val, curr_epoch)
             self.tb_writer.flush()

-        if self.save_all_checkpoints and curr_epoch > 0:
+        if self.cfg.save_all_checkpoints and curr_epoch > 0:
             checkpoint_name = f'model-ckpt-epoch-{curr_epoch - 1}.pth'
-            checkpoint_path = join(self.output_dir_local, checkpoint_name)
-            shutil.copy(self.last_model_weights_path, checkpoint_path)
+            checkpoint_path = join(self.checkpoints_dir_local, checkpoint_name)
+            shutil.move(self.last_model_weights_path, checkpoint_path)

         torch.save(self.model.state_dict(), self.last_model_weights_path)
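Summarizing the behavior this commit produces: at the end of each epoch, the previous epoch's last-model.pth is moved into the new checkpoints/ directory (as model-ckpt-epoch-{N}.pth) before the current weights are written to last-model.pth. Below is a minimal standalone sketch of that rotation, not the actual Learner.on_epoch_end method; the function name and arguments are stand-ins for the corresponding attributes on Learner.

import shutil
from os.path import join

import torch


def save_epoch_weights(model: torch.nn.Module, cfg, curr_epoch: int,
                       checkpoints_dir_local: str,
                       last_model_weights_path: str) -> None:
    # Sketch of the end-of-epoch checkpoint rotation after this commit.
    if cfg.save_all_checkpoints and curr_epoch > 0:
        checkpoint_name = f'model-ckpt-epoch-{curr_epoch - 1}.pth'
        checkpoint_path = join(checkpoints_dir_local, checkpoint_name)
        # Move (not copy) the previous epoch's weights into checkpoints/,
        # so each epoch's weights exist on disk exactly once.
        shutil.move(last_model_weights_path, checkpoint_path)

    # The most recent weights always live at last-model.pth.
    torch.save(model.state_dict(), last_model_weights_path)

Relative to the previous behavior, shutil.copy became shutil.move (no duplicate of the prior epoch's weights is kept) and per-epoch files land in checkpoints/ instead of the output root.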
File 2:

@@ -1377,6 +1377,13 @@ class LearnerConfig(Config):
         False, description='run Tensorboard server during training')
     output_uri: Optional[str] = Field(
         None, description='URI of where to save output')
+    save_all_checkpoints: bool = Field(
+        False,
+        description=
+        ('If True, all checkpoints would be saved. The latest checkpoint '
+         'would be saved as `last-model.pth`. The checkpoints prior to '
+         'last epoch are stored as `model-ckpt-epoch-{N}.pth` where `N` '
+         'is the epoch number.'))

     @validator('run_tensorboard')
     def validate_run_tensorboard(cls, v: bool, values: dict) -> bool:
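A usage sketch under stated assumptions: it reuses the hypothetical save_epoch_weights function above, a throwaway torch.nn.Linear model, and a SimpleNamespace standing in for the config object; only the save_all_checkpoints flag and the file naming come from the commit.

import os
import tempfile
from os.path import join
from types import SimpleNamespace

import torch

cfg = SimpleNamespace(save_all_checkpoints=True)  # stand-in for LearnerConfig
model = torch.nn.Linear(4, 2)                     # dummy model

output_dir = tempfile.mkdtemp()
checkpoints_dir = join(output_dir, 'checkpoints')
os.makedirs(checkpoints_dir, exist_ok=True)
last_model_weights_path = join(output_dir, 'last-model.pth')

for epoch in range(3):
    save_epoch_weights(model, cfg, epoch, checkpoints_dir,
                       last_model_weights_path)

print(sorted(os.listdir(output_dir)))       # ['checkpoints', 'last-model.pth']
print(sorted(os.listdir(checkpoints_dir)))  # ['model-ckpt-epoch-0.pth', 'model-ckpt-epoch-1.pth']

With the flag left at its default of False, only last-model.pth is written and the checkpoints/ directory stays empty.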
