Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save model weights for each epoch 1720 #1921

Merged

Conversation

mmcs-work
Copy link
Contributor

@mmcs-work mmcs-work commented Sep 17, 2023

Overview

Save model weights for each epoch.

Notes

From #1720, proposition 1 is followed. Last checkpoint is still kept as last-model.pth. Instead of overwriting, it copies the previous checkpoint to model-ckpt-epoch-{N}.pth where N is the epoch number.

Testing Instructions

  • From either running the notebooks or the unittests where Learner is extended. The option save_all_checkpoints can be added as in
learner = SemanticSegmentationLearner(
    cfg=learner_cfg,
    output_dir='./train-demo/',
    model=model,
    train_ds=train_ds,
    valid_ds=val_ds,
    save_all_checkpoints=True
)

Closes #1720

@mmcs-work mmcs-work force-pushed the save-model-weights-for-each-epoch-1720 branch from 07b943d to c6fa61d Compare September 18, 2023 06:53
@AdeelH
Copy link
Collaborator

AdeelH commented Sep 18, 2023

This is a simple and elegant solution to this problem. Really appreciate the high-quality PR!

Some minor changes I would suggest are:

  • Make save_all_checkpoints a field in LearnerConfig instead of an argument to Learner.
  • Store the checkpoints in a checkpoints subdirectory under the Learner's output directory to avoid cluttering up the latter. You can use the make_dir() which is already available in learner.py.

Additionally, if we want to make this field usable in the RV-as-a-framework context, we would also want to add it to PyTorchLearnerBackendConfig and then modify the get_learner_config() methods of all 3 of its subclasses to pass it to the respective LearnerConfig subclasses (identical to how log_tensorboard and run_tensorboard are handled). If you are only using RV as a library then this is probably not useful/relevant to you, so it is totally okay if you want to skip this in this PR!

mmcs-work added a commit to mmcs-work/raster-vision that referenced this pull request Sep 18, 2023
mmcs-work added a commit to mmcs-work/raster-vision that referenced this pull request Sep 18, 2023
Adds `save_all_checkpoints` to `PyTorchlearnerBackendConfig`
Adds `save_all_checkpoints` to `PyTorchlearnerBackendConfig`
@mmcs-work mmcs-work force-pushed the save-model-weights-for-each-epoch-1720 branch from a3bc602 to a1e60e6 Compare September 18, 2023 23:56
@AdeelH
Copy link
Collaborator

AdeelH commented Sep 19, 2023

Looks like there are linter errors. You can run scripts/style_tests locally to find those. And also scripts/format_code to format your code correctly.

If you are using VS Code, you can install the flake8 extension and set yapf as the formatter to ensure that your code is always compliant. Be sure to install the flake8 and yapf versions from requirements-dev.txt.

@codecov
Copy link

codecov bot commented Sep 19, 2023

Codecov Report

Merging #1921 (e04168d) into master (30fd26f) will decrease coverage by 0.02%.
The diff coverage is 66.66%.

@@            Coverage Diff             @@
##           master    #1921      +/-   ##
==========================================
- Coverage   82.49%   82.47%   -0.02%     
==========================================
  Files         190      190              
  Lines        9316     9325       +9     
==========================================
+ Hits         7685     7691       +6     
- Misses       1631     1634       +3     
Files Changed Coverage Δ
...orch_backend/pytorch_chip_classification_config.py 62.96% <ø> (ø)
...pytorch_backend/pytorch_object_detection_config.py 61.53% <ø> (ø)
...ch_backend/pytorch_semantic_segmentation_config.py 62.96% <ø> (ø)
...ch_learner/rastervision/pytorch_learner/learner.py 75.66% <57.14%> (-0.24%) ⬇️
.../pytorch_backend/pytorch_learner_backend_config.py 68.75% <100.00%> (+0.66%) ⬆️
...ner/rastervision/pytorch_learner/learner_config.py 83.12% <100.00%> (+0.03%) ⬆️

Copy link
Collaborator

@AdeelH AdeelH left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me. Thank you!

@AdeelH AdeelH merged commit 4ee8af0 into azavea:master Sep 19, 2023
1 check passed
@mmcs-work
Copy link
Contributor Author

This looks good to me. Thank you!

Thanks for your help.

@AdeelH AdeelH added the needs-backport This PR needs to be backported to release branches label Sep 22, 2023
AdeelH pushed a commit to AdeelH/raster-vision that referenced this pull request Sep 22, 2023
* Save model's weights for each epoch azavea#1720

* Fixes azavea#1720 (azavea#1921 (comment) - first point) Adds `save_all_checkpoints` to `LearnerConfig`

* Fixes azavea#1720 (azavea#1921 (comment) (comment) - second point)
Adds `save_all_checkpoints` to `PyTorchlearnerBackendConfig`

* Formats code (yapf) azavea#1720
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
needs-backport This PR needs to be backported to release branches
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Save model's weights for each epoch
2 participants