Skip to content

Commit

Permalink
Save model weights for each epoch (azavea#1720) (azavea#1921)
Browse files Browse the repository at this point in the history
* Save model's weights for each epoch azavea#1720

* Fixes azavea#1720 (azavea#1921 (comment) - first point) Adds `save_all_checkpoints` to `LearnerConfig`

* Fixes azavea#1720 (azavea#1921 (comment) - second point)
Adds `save_all_checkpoints` to `PyTorchLearnerBackendConfig`

* Formats code (yapf) azavea#1720
  • Loading branch information
mmcs-work authored and AdeelH committed Sep 22, 2023
1 parent 48999ab commit 0368421
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def get_learner_config(self, pipeline):
test_mode=self.test_mode,
output_uri=pipeline.train_uri,
log_tensorboard=self.log_tensorboard,
run_tensorboard=self.run_tensorboard)
run_tensorboard=self.run_tensorboard,
save_all_checkpoints=self.save_all_checkpoints)
learner.update()
learner.validate_config()
return learner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ class PyTorchLearnerBackendConfig(BackendConfig):
('This field is passed along to the LearnerConfig which is returned by '
'get_learner_config(). For more info, see the docs for'
'pytorch_learner.learner_config.LearnerConfig.test_mode.'))
save_all_checkpoints: bool = Field(
False,
description=(
'If True, all checkpoints would be saved. The latest checkpoint '
'would be saved as `last-model.pth`. The checkpoints prior to '
'last epoch are stored as `model-ckpt-epoch-{N}.pth` where `N` '
'is the epoch number.'))

def get_bundle_filenames(self):
return ['model-bundle.zip']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def get_learner_config(self, pipeline):
test_mode=self.test_mode,
output_uri=pipeline.train_uri,
log_tensorboard=self.log_tensorboard,
run_tensorboard=self.run_tensorboard)
run_tensorboard=self.run_tensorboard,
save_all_checkpoints=self.save_all_checkpoints)
learner.update()
return learner

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def get_learner_config(self, pipeline):
test_mode=self.test_mode,
output_uri=pipeline.train_uri,
log_tensorboard=self.log_tensorboard,
run_tensorboard=self.run_tensorboard)
run_tensorboard=self.run_tensorboard,
save_all_checkpoints=self.save_all_checkpoints)
learner.update()
learner.validate_config()
return learner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

warnings.filterwarnings('ignore')

CHECKPOINTS_DIRNAME = 'checkpoints'
MODULES_DIRNAME = 'modules'
TRANSFORMS_DIRNAME = 'custom_albumentations_transforms'
BUNDLE_MODEL_WEIGHTS_FILENAME = 'model.pth'
Expand Down Expand Up @@ -175,7 +176,6 @@ def __init__(self,
else:
self.output_dir = output_dir
self.model_bundle_uri = join(self.output_dir, 'model-bundle.zip')

if is_local(self.output_dir):
self.output_dir_local = self.output_dir
make_dir(self.output_dir_local)
Expand All @@ -188,6 +188,10 @@ def __init__(self,
log.info(f'Remote output dir: {self.output_dir}')

self.modules_dir = join(self.output_dir, MODULES_DIRNAME)
self.checkpoints_dir_local = join(self.output_dir_local,
CHECKPOINTS_DIRNAME)
make_dir(self.checkpoints_dir_local)

# ---------------------------
self._onnx_mode = False
self.setup_model(
Expand Down Expand Up @@ -522,6 +526,11 @@ def on_epoch_end(self, curr_epoch: int, metrics: MetricDict) -> None:
self.tb_writer.add_scalar(key, val, curr_epoch)
self.tb_writer.flush()

if self.cfg.save_all_checkpoints and curr_epoch > 0:
checkpoint_name = f'model-ckpt-epoch-{curr_epoch - 1}.pth'
checkpoint_path = join(self.checkpoints_dir_local, checkpoint_name)
shutil.move(self.last_model_weights_path, checkpoint_path)

torch.save(self.model.state_dict(), self.last_model_weights_path)

if (curr_epoch + 1) % self.cfg.solver.sync_interval == 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,13 @@ class LearnerConfig(Config):
False, description='run Tensorboard server during training')
output_uri: Optional[str] = Field(
None, description='URI of where to save output')
save_all_checkpoints: bool = Field(
False,
description=(
'If True, all checkpoints would be saved. The latest checkpoint '
'would be saved as `last-model.pth`. The checkpoints prior to '
'last epoch are stored as `model-ckpt-epoch-{N}.pth` where `N` '
'is the epoch number.'))

@validator('run_tensorboard')
def validate_run_tensorboard(cls, v: bool, values: dict) -> bool:
Expand Down

0 comments on commit 0368421

Please sign in to comment.