Skip to content

Commit

Permalink
Fixes #1720 (#1921 (comment) (comment) - second point)
Browse files Browse the repository at this point in the history
Adds `save_all_checkpoints` to `PyTorchlearnerBackendConfig`
  • Loading branch information
mmcs-work committed Sep 18, 2023
1 parent 0c6871d commit a1e60e6
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def get_learner_config(self, pipeline):
test_mode=self.test_mode,
output_uri=pipeline.train_uri,
log_tensorboard=self.log_tensorboard,
run_tensorboard=self.run_tensorboard)
run_tensorboard=self.run_tensorboard,
save_all_checkpoints=self.save_all_checkpoints)
learner.update()
learner.validate_config()
return learner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ class PyTorchLearnerBackendConfig(BackendConfig):
('This field is passed along to the LearnerConfig which is returned by '
'get_learner_config(). For more info, see the docs for'
'pytorch_learner.learner_config.LearnerConfig.test_mode.'))
save_all_checkpoints: bool = Field(
False,
description=
('If True, all checkpoints would be saved. The latest checkpoint '
'would be saved as `last-model.pth`. The checkpoints prior to '
'last epoch are stored as `model-ckpt-epoch-{N}.pth` where `N` '
'is the epoch number.'))

def get_bundle_filenames(self):
return ['model-bundle.zip']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def get_learner_config(self, pipeline):
test_mode=self.test_mode,
output_uri=pipeline.train_uri,
log_tensorboard=self.log_tensorboard,
run_tensorboard=self.run_tensorboard)
run_tensorboard=self.run_tensorboard,
save_all_checkpoints=self.save_all_checkpoints)
learner.update()
return learner

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def get_learner_config(self, pipeline):
test_mode=self.test_mode,
output_uri=pipeline.train_uri,
log_tensorboard=self.log_tensorboard,
run_tensorboard=self.run_tensorboard)
run_tensorboard=self.run_tensorboard,
save_all_checkpoints=self.save_all_checkpoints)
learner.update()
learner.validate_config()
return learner
Expand Down

0 comments on commit a1e60e6

Please sign in to comment.