Add option to reuse .bin files #116

Open
wants to merge 7 commits into base: divya/load-head-ckpt-inference

Conversation


@gitttt-1234 gitttt-1234 commented Nov 2, 2024

This PR introduces an option to enable/disable the auto-deletion of .bin files (data chunks for training) generated by ld.optimize. Additionally, it provides the flexibility to load existing train and validation chunks into any training process by passing the paths to the .bin folders to the ModelTrainer.train() function.
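
A minimal usage sketch of the new options (the parameter names follow the ModelTrainer.train() signature discussed in the review below; the config file and chunk paths are placeholders):

```python
from omegaconf import OmegaConf
from sleap_nn.training.model_trainer import ModelTrainer

config = OmegaConf.load("initial_config.yaml")

# First run: keep the generated .bin chunks instead of auto-deleting them.
trainer = ModelTrainer(config)
trainer.train(delete_bin_files_after_training=False)

# Later run: reuse the previously generated chunks instead of re-creating them.
trainer = ModelTrainer(config)
trainer.train(
    train_chunks_dir_path="path/to/train_chunks",
    val_chunks_dir_path="path/to/val_chunks",
)
```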

Summary by CodeRabbit

  • New Features

    • Enhanced configuration documentation for model training, including new parameters for data handling and learning rate scheduling.
    • Introduced new parameters such as chunk_size and user_instances_only across various configuration files.
    • Added support for specifying alternative checkpoint files for model inference.
  • Bug Fixes

    • Improved error handling for model loading and training processes.
  • Tests

    • Expanded test coverage for model loading and training configurations, ensuring robustness against various scenarios.


coderabbitai bot commented Nov 2, 2024

Important

Review skipped

Auto reviews are disabled on base/target branches other than the default branch.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

The pull request introduces extensive modifications to the configuration documentation and related YAML files for the sleap_nn.ModelTrainer class. Key changes include the restructuring of configuration files into three main sections: data_config, model_config, and trainer_config, with detailed parameter updates. New parameters for learning rate scheduling and early stopping have been added, and existing structures have been reorganized for clarity. Additionally, the changes enhance the configurability of model training processes and improve the handling of model checkpoints.

Changes

| File Path | Change Summary |
| --- | --- |
| docs/config.md | Updated documentation for sleap_nn.ModelTrainer configuration, detailing data_config, model_config, and trainer_config sections. Added new parameters for learning rate scheduling and early stopping. |
| docs/config_bottomup.yaml | Added chunk_size under data_config. Replaced lr_scheduler structure with scheduler set to ReduceLROnPlateau, nesting related parameters under reduce_lr_on_plateau. |
| docs/config_centroid.yaml | Added chunk_size under data_config. Updated lr_scheduler to a nested structure under reduce_lr_on_plateau. Removed log_params from wandb configuration. |
| docs/config_topdown_centered_instance.yaml | Added chunk_size under data_config and min_crop_size in preprocessing. Updated lr_scheduler to a nested structure under reduce_lr_on_plateau. |
| initial_config.yaml | Introduced a new configuration structure with sections for data_config, model_config, and trainer_config. Added parameters for data handling, model architecture, and training settings, including learning rate scheduling. |
| sleap_nn/data/streaming_datasets.py | Modified CenteredInstanceStreamingDataset to recalculate crop_hw based on input_scale during initialization. |
| sleap_nn/inference/predictors.py | Enhanced Predictor class to include new parameters for loading alternative model checkpoints for backbone and head layers. |
| sleap_nn/training/model_trainer.py | Refactored methods in ModelTrainer for better handling of input paths and model initialization. Expanded learning rate scheduler configuration to support multiple types. |
| tests/assets/minimal_instance/initial_config.yaml | Added new parameters under data_config. Updated lr_scheduler configuration to a nested structure. Removed log_params. |
| tests/assets/minimal_instance/training_config.yaml | Added new parameters under data_config. Updated lr_scheduler configuration to a nested structure. Removed log_params. |
| tests/assets/minimal_instance_bottomup/initial_config.yaml | Added new parameters under data_config. Updated lr_scheduler configuration to a nested structure. Removed log_params. |
| tests/assets/minimal_instance_bottomup/training_config.yaml | Added new parameters under data_config. Updated lr_scheduler configuration to a nested structure. Removed log_params. |
| tests/assets/minimal_instance_centroid/initial_config.yaml | Added new parameters under data_config. Updated lr_scheduler configuration to a nested structure. Removed log_params. |
| tests/assets/minimal_instance_centroid/training_config.yaml | Added new parameters under data_config. Updated lr_scheduler configuration to a nested structure. Removed log_params. |
| tests/fixtures/datasets.py | Changed provider in data_config from "LabelsReaderDP" to "LabelsReader". Added new keys for user_instances_only and chunk_size. Restructured lr_scheduler to a nested format. |
| tests/inference/test_predictors.py | Updated tests for Predictor class, adding parameters for loading different configurations and verifying loaded weights. |
| tests/training/test_model_trainer.py | Modified tests for ModelTrainer, adding new tests and updating configurations for learning rate scheduling. |

Suggested reviewers

  • talmo

🐰 Hopping through the fields so bright,
New configs bring a joyful sight!
With chunk sizes and schedulers in play,
Our models will learn in a better way!
So let’s cheer for the changes made,
In the world of training, we’ve got it made! 🐇✨



@gitttt-1234 gitttt-1234 changed the base branch from main to divya/load-head-ckpt-inference November 2, 2024 00:13

@coderabbitai coderabbitai bot left a comment


Caution

Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments.

Actionable comments posted: 20

🧹 Outside diff range and nitpick comments (35)
tests/assets/minimal_instance_centroid/initial_config.yaml (1)

78-85: Review learning rate scheduler parameters.

The ReduceLROnPlateau configuration uses very small values:

  • threshold: 1.0e-07 and min_lr: 1.0e-08 are quite low
  • patience: 5 with cooldown: 3 means LR changes can occur frequently

Consider if these values are appropriate for your model and training dynamics. Too-frequent LR changes might destabilize training.

Recommendations:

  1. Consider increasing the threshold to ~1e-4 for more stable training
  2. You might want to increase cooldown to prevent too-frequent LR changes
  3. Document why these specific values were chosen
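
For illustration, a sketch of what those recommendations could look like in the YAML (key names follow the other config snippets in this review; the exact values remain a judgment call for the maintainers):

```yaml
lr_scheduler:
  scheduler: ReduceLROnPlateau
  reduce_lr_on_plateau:
    threshold: 1.0e-04     # raised from 1.0e-07 for more stable training
    threshold_mode: abs
    patience: 5
    cooldown: 5            # raised from 3 to avoid too-frequent LR changes
    factor: 0.5
    min_lr: 1.0e-08
```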
tests/assets/minimal_instance_bottomup/initial_config.yaml (1)

70-71: Document .bin file deletion behavior.

Since this PR introduces the option to reuse .bin files, it would be helpful to add a comment or parameter that explicitly indicates whether .bin files will be automatically deleted after training. This would make the behavior more transparent to users.

Consider adding a parameter like delete_bin_files: true with appropriate documentation.

tests/assets/minimal_instance_centroid/training_config.yaml (1)

89-96: Consider separating LR scheduler changes into a different PR

The learning rate scheduler changes appear to be unrelated to the main objective of this PR (reusing .bin files). While the configuration looks well-tuned, it might be clearer to handle these changes in a separate PR focused on training optimizations.

The current settings are reasonable:

  • Reduction factor of 0.5
  • Patience of 5 epochs with 3 epochs cooldown
  • Minimum learning rate of 1e-8
tests/assets/minimal_instance_bottomup/training_config.yaml (2)

85-85: Document the bin_files_path usage.

This parameter enables the reuse of .bin files as described in the PR objectives. Consider:

  1. Adding a comment explaining the expected path format
  2. Documenting whether relative paths are supported
  3. Clarifying what happens when the path is invalid

Add a YAML comment above this line:

+  # Path to existing .bin files for reusing training data chunks. Set to null to generate new chunks.
   bin_files_path:

99-106: LGTM! Well-structured learning rate scheduler configuration.

The ReduceLROnPlateau configuration has sensible defaults:

  • Gradual reduction (factor: 0.5)
  • Reasonable patience (5 epochs) and cooldown (3 epochs)
  • Safe minimum learning rate (1e-8)

Consider adding comments to explain the threshold_mode options (abs vs rel) for future maintainers.
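
A sketch of such comments, based on PyTorch's ReduceLROnPlateau semantics when minimizing a loss:

```yaml
reduce_lr_on_plateau:
  threshold: 1.0e-07
  # "rel": an epoch counts as an improvement only if new_loss < best * (1 - threshold)
  # "abs": an epoch counts as an improvement only if new_loss < best - threshold
  threshold_mode: abs
```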

tests/assets/minimal_instance/training_config.yaml (2)

5-6: Document the new data configuration parameters.

New parameters have been added without documentation explaining their purpose and impact:

  • user_instances_only
  • chunk_size
  • min_crop_size

Please add documentation describing:

  • What each parameter controls
  • Expected values/ranges
  • Default behavior

Also applies to: 15-15


Line range hint 1-103: Consider splitting changes into focused PRs.

The current changes mix multiple concerns:

  1. .bin file reuse functionality (primary objective)
  2. Learning rate scheduler modifications
  3. Data preprocessing parameters

This makes the changes harder to review and maintain. Consider:

  1. Keeping only the .bin file reuse changes in this PR
  2. Moving LR scheduler and preprocessing changes to separate PRs
  3. Adding comprehensive documentation for all new parameters

Additionally, please add validation to ensure the bin_files_path points to a valid directory when provided.
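
A sketch of the suggested validation (its placement inside ModelTrainer setup is an assumption; only the bin_files_path name comes from this PR):

```python
from pathlib import Path
from typing import Optional


def _validate_bin_files_path(bin_files_path: Optional[str]) -> None:
    """Fail early if a provided bin_files_path does not point to an existing directory."""
    if bin_files_path is not None and not Path(bin_files_path).is_dir():
        raise ValueError(
            f"`bin_files_path` must be an existing directory, got: {bin_files_path}"
        )
```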

docs/config_bottomup.yaml (2)

6-6: Add documentation for the chunk_size parameter.

Since this parameter is crucial for the new feature of managing training data chunks, please add a comment explaining its purpose, impact on training, and any constraints on its value.

 data_config:
   provider: LabelsReader
   train_labels_path: minimal_instance.pkg.slp
   val_labels_path: minimal_instance.pkg.slp
   user_instances_only: True
+  # Size of data chunks for training. Controls how many frames are processed
+  # together when generating .bin files
   chunk_size: 100

Line range hint 92-92: Document the bin_files_path parameter.

Since this PR introduces the ability to reuse .bin files, please add documentation explaining:

  1. The purpose of this parameter
  2. Expected path format
  3. Behavior when the path is empty vs. when it contains a value
   save_ckpt: true
   save_ckpt_path: min_inst_bottomup1
+  # Path to existing .bin files for reusing training data chunks
+  # Leave empty to generate new chunks, or specify path to reuse existing ones
   bin_files_path:
docs/config_topdown_centered_instance.yaml (3)

6-6: Document the chunk_size parameter's purpose and impact.

The newly added chunk_size parameter aligns with the PR's objective of managing training data chunks, but its purpose and impact should be documented for users.

Consider adding a comment explaining:

  • What this parameter controls
  • How it affects the training process
  • Any recommended values or constraints

107-114: LGTM! Well-structured learning rate scheduler configuration.

The ReduceLROnPlateau configuration is comprehensive with appropriate parameters for stable training. It pairs well with the early stopping configuration.

Consider adding inline comments explaining the purpose of each parameter, especially threshold_mode: abs vs. rel, to help users customize these values for their needs.


Line range hint 92-93: Add parameter to control .bin files deletion.

The PR aims to allow users to control the automatic deletion of .bin files, but this configuration is missing a parameter to enable/disable this feature.

Consider adding a parameter like retain_bin_files: false near the bin_files_path configuration to control this behavior.

  bin_files_path:
+ retain_bin_files: false  # Controls whether to keep .bin files after training
🧰 Tools
🪛 yamllint

[error] 8-8: trailing spaces

(trailing-spaces)


[error] 9-9: trailing spaces

(trailing-spaces)

initial_config.yaml (2)

77-77: Document the bin_files_path parameter.

This new parameter aligns with the PR objective to reuse .bin files, but needs documentation about:

  • Expected directory structure
  • File naming conventions
  • Any requirements for the binary files

Consider adding a comment above this parameter explaining its usage:

+  # Path to directory containing pre-generated .bin training chunks
+  # If provided, these chunks will be reused instead of generating new ones
   bin_files_path: null

73-73: Add warning about GPU determinism.

While setting a fixed seed is good for reproducibility, users should be aware that complete determinism requires additional settings when using GPUs.

Consider adding a warning comment:

+  # Note: For complete determinism with GPUs, additional PyTorch settings are required
   seed: 1000
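
For reference, a sketch of the additional PyTorch settings that warning alludes to (not part of this PR; whether Lightning already applies some of these is something to verify):

```python
import torch

torch.manual_seed(1000)                    # same seed as the config
torch.use_deterministic_algorithms(True)   # error out on nondeterministic ops
torch.backends.cudnn.deterministic = True  # force deterministic cuDNN kernels
torch.backends.cudnn.benchmark = False     # disable autotuning, which is nondeterministic
```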
docs/config_centroid.yaml (2)

6-6: Consider documenting chunk_size parameter and its implications.

While the chunk_size parameter has been added, its relationship to the .bin files and data chunking process should be documented in comments.

 data_config:
   provider: LabelsReader
   train_labels_path: minimal_instance.pkg.slp
   val_labels_path: minimal_instance.pkg.slp
   user_instances_only: True
-  chunk_size: 100
+  # Size of data chunks used for training. Affects the size of generated .bin files
+  chunk_size: 100

Line range hint 89-89: Add configuration for .bin file retention.

The PR aims to add an option to control .bin file deletion, but there's no visible configuration parameter for this feature. Consider adding a parameter to control this behavior.

   save_ckpt: true
   save_ckpt_path: 'min_inst_centroid'
   bin_files_path:
+  # Set to true to retain .bin files after training (default: false)
+  retain_bin_files: false
   resume_ckpt_path:
🧰 Tools
🪛 yamllint

[error] 8-8: trailing spaces

(trailing-spaces)


[error] 9-9: trailing spaces

(trailing-spaces)

sleap_nn/data/streaming_datasets.py (2)

154-155: Consider adding input validation for crop dimensions.

Since the crop dimensions are critical for proper functioning and are now calculated once during initialization, consider adding validation to ensure the scaled dimensions remain valid.

Add input validation:

 # Re-crop to original crop size
 self.crop_hw = [int(x * self.input_scale) for x in self.crop_hw]
+if not all(x > 0 for x in self.crop_hw):
+    raise ValueError(f"Invalid crop dimensions after scaling: {self.crop_hw}. Check input_scale: {self.input_scale}")

154-155: Document the relationship with .bin files.

Since this PR focuses on .bin file reuse, it would be helpful to document how the crop size affects the data chunks stored in .bin files.

Add documentation:

 # Re-crop to original crop size
+# Note: These dimensions affect the data chunks stored in .bin files.
+# When reusing .bin files, ensure consistent crop_hw and input_scale values
 self.crop_hw = [int(x * self.input_scale) for x in self.crop_hw]
docs/config.md (2)

Line range hint 1-188: Documentation missing for .bin files reuse feature.

The PR introduces the ability to reuse .bin files, but the documentation only mentions the bin_files_path parameter for saving them. Please add documentation for:

  • Parameters controlling automatic deletion of .bin files
  • How to specify existing .bin files for reuse in training

Would you like me to help draft the documentation for these new parameters?

🧰 Tools
🪛 LanguageTool

[uncategorized] ~179-~179: Loose punctuation mark.
Context: ...ReduceLROnPlateau". - step_lr: - step_size`: (int) Period...

(UNLIKELY_OPENING_PUNCTUATION)


[uncategorized] ~182-~182: Loose punctuation mark.
Context: ...*: 0.1. - reduce_lr_on_plateau: - threshold: (float) Thre...

(UNLIKELY_OPENING_PUNCTUATION)

🪛 Markdownlint

170-170: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


171-171: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


172-172: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


173-173: Expected: 2; Actual: 4
Unordered list indentation

(MD007, ul-indent)


174-174: Expected: 2; Actual: 4
Unordered list indentation

(MD007, ul-indent)


175-175: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


176-176: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


177-177: Expected: 2; Actual: 4
Unordered list indentation

(MD007, ul-indent)


178-178: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


179-179: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


180-180: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


181-181: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


182-182: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


183-183: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


184-184: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


185-185: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


186-186: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


187-187: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


188-188: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


189-189: Expected: 2; Actual: 4
Unordered list indentation

(MD007, ul-indent)


190-190: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


191-191: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


184-184: null
Spaces inside emphasis markers

(MD037, no-space-in-emphasis)


184-184: null
Spaces inside emphasis markers

(MD037, no-space-in-emphasis)


178-188: Enhance scheduler configuration documentation with examples.

The scheduler configuration is well-documented but could benefit from practical examples. Consider adding:

  1. Example configurations for both scheduler types
  2. Common use cases for each parameter
  3. Guidelines for choosing between StepLR and ReduceLROnPlateau

Example addition:

Example configurations:
```yaml
# StepLR: Reduce learning rate by half every 10 epochs
lr_scheduler:
  scheduler: "StepLR"
  step_lr:
    step_size: 10
    gamma: 0.5

# ReduceLROnPlateau: Reduce learning rate when validation loss plateaus
lr_scheduler:
  scheduler: "ReduceLROnPlateau"
  reduce_lr_on_plateau:
    patience: 5
    factor: 0.1
    threshold: 1e-4
```

🧰 Tools
🪛 LanguageTool

[uncategorized] ~179-~179: Loose punctuation mark.
Context: ...ReduceLROnPlateau"`.         - `step_lr`:             - `step_size`: (int) Period...

(UNLIKELY_OPENING_PUNCTUATION)

---

[uncategorized] ~182-~182: Loose punctuation mark.
Context: ...*: 0.1.         - `reduce_lr_on_plateau`:             - `threshold`: (float) Thre...

(UNLIKELY_OPENING_PUNCTUATION)

🪛 Markdownlint

178-178: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)

---

179-179: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)

---

180-180: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)

---

181-181: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)

---

182-182: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)

---

183-183: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)

---

184-184: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)

---

185-185: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)

---

186-186: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)

---

187-187: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)

---

188-188: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)

---

184-184: null
Spaces inside emphasis markers

(MD037, no-space-in-emphasis)

---

184-184: null
Spaces inside emphasis markers

(MD037, no-space-in-emphasis)

sleap_nn/inference/predictors.py (3)

94-95: Add parameter validation for checkpoint paths.

Consider adding validation to ensure that if `head_ckpt_path` is provided, `backbone_ckpt_path` must also be provided, as loading only head weights without backbone weights might lead to unexpected behavior.

```diff
 def from_model_paths(
     cls,
     model_paths: List[Text],
     backbone_ckpt_path: Optional[str] = None,
     head_ckpt_path: Optional[str] = None,
+    # Add validation
+    if head_ckpt_path is not None and backbone_ckpt_path is None:
+        raise ValueError("backbone_ckpt_path must be provided when head_ckpt_path is provided")

1622-1626: Improve parameter documentation in docstring.

The docstring for the new parameters could be more descriptive and include examples:

-        backbone_ckpt_path: (str) To run inference on any `.ckpt` other than `best.ckpt`
-                from the `model_paths` dir, the path to the `.ckpt` file should be passed here.
-        head_ckpt_path: (str) Path to `.ckpt` file if a different set of head layer weights
-                are to be used. If `None`, the `best.ckpt` from `model_paths` dir is used (or the ckpt
-                from `backbone_ckpt_path` if provided.)
+        backbone_ckpt_path: Optional path to a checkpoint file containing backbone weights.
+                If provided, these weights will be used instead of the backbone weights from
+                `best.ckpt`. This allows mixing weights from different checkpoints.
+                Example: "/path/to/backbone_v2.ckpt"
+        head_ckpt_path: Optional path to a checkpoint file containing head layer weights.
+                Can only be used if backbone_ckpt_path is also provided. This enables using
+                different head weights while maintaining the same backbone.
+                Example: "/path/to/head_specialized.ckpt"

596-597: Add error handling for checkpoint loading.

Consider adding try-except blocks when loading checkpoints to handle potential errors gracefully:

  • File not found
  • Invalid checkpoint format
  • Incompatible state dict structure
+def _safe_load_checkpoint(path: str) -> dict:
+    """Safely load a checkpoint file with error handling."""
+    try:
+        return torch.load(path)
+    except FileNotFoundError:
+        raise ValueError(f"Checkpoint file not found: {path}")
+    except Exception as e:
+        raise ValueError(f"Error loading checkpoint {path}: {str(e)}")

Also applies to: 606-607, 610-611

tests/training/test_model_trainer.py (2)

216-218: Simplify the OmegaConf.update call formatting

The OmegaConf.update call can be condensed into a single line for readability, as it doesn't exceed typical line length limits.

Apply this diff to improve formatting:

-OmegaConf.update(
-    config_early_stopping, "trainer_config.lr_scheduler.scheduler", None
-)
+OmegaConf.update(config_early_stopping, "trainer_config.lr_scheduler.scheduler", None)

330-333: Clarify intentional use of invalid scheduler name

The scheduler name "ReduceLR" is likely intended to be invalid to test exception handling. Consider adding a comment to clarify this for future maintainability.

Apply this diff to add a clarifying comment:

+ # Intentionally using an invalid scheduler name to test exception handling
 OmegaConf.update(config, "trainer_config.lr_scheduler.scheduler", "ReduceLR")
 with pytest.raises(ValueError):
     trainer = ModelTrainer(config)
tests/inference/test_predictors.py (3)

3-3: Remove unused import Text

The Text class from the typing module is imported but not used in the code. Removing this unused import will clean up the code.

Apply this diff to remove the unused import:

-from typing import Text
🧰 Tools
🪛 Ruff

3-3: typing.Text imported but unused

Remove unused import: typing.Text

(F401)


692-693: Remove unnecessary debug print statements

The print statements at lines 692-693 and 701-702 seem to be leftover from debugging. Removing these will keep test output clean and focus on relevant information.

Apply this diff to remove the print statements:

-    print(f"head_layer_ckpt: {head_layer_ckpt}")
-    print(model_weights)

Also applies to: 701-702



sleap_nn/training/model_trainer.py (7)

376-383: Add docstrings for new parameters in train method

The train method now accepts additional parameters that are not documented. Providing docstrings for these parameters enhances code readability and user understanding.

Apply this diff to add the parameter documentation:

 def train(
     self,
     backbone_trained_ckpts_path: Optional[str] = None,
     head_trained_ckpts_path: Optional[str] = None,
     delete_bin_files_after_training: bool = True,
     train_chunks_dir_path: Optional[str] = None,
     val_chunks_dir_path: Optional[str] = None,
 ):
+    """
+    Initiate the training process.
+
+    Args:
+        backbone_trained_ckpts_path: Path to a trained backbone checkpoint for model initialization.
+        head_trained_ckpts_path: Path to a trained head checkpoint for model initialization.
+        delete_bin_files_after_training: Whether to delete `.bin` files after training. Defaults to True.
+        train_chunks_dir_path: Path to existing training chunks directory. If None, new chunks will be created.
+        val_chunks_dir_path: Path to existing validation chunks directory. If None, new chunks will be created.
+    """
     logger = []

354-358: Add parameter descriptions to the _initialize_model method docstring

The _initialize_model method accepts new parameters that aren't documented. Adding descriptions will improve code clarity.

Apply this diff to add the parameter documentation:

 def _initialize_model(
     self,
     backbone_trained_ckpts_path: Optional[str] = None,
     head_trained_ckpts_path: Optional[str] = None,
 ):
+    """
+    Initialize the model with optional pretrained checkpoints.
+
+    Args:
+        backbone_trained_ckpts_path: Path to a pretrained backbone checkpoint.
+        head_trained_ckpts_path: Path to a pretrained head checkpoint.
+    """
     models = {

528-529: Correct typos in docstring parameter descriptions

In the docstring, replace "ckpts" with "checkpoints" for clarity.

Apply this diff to correct the typos:

             model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-            backbone_trained_ckpts_path: Path to trained ckpts for backbone.
-            head_trained_ckpts_path: Path to trained ckpts for head layer.
+            backbone_trained_ckpts_path: Path to trained checkpoints for backbone.
+            head_trained_ckpts_path: Path to trained checkpoints for head layer.
         """

743-744: Correct typos in docstring parameter descriptions

In the docstring of SingleInstanceModel, replace "ckpts" with "checkpoints" for clarity.

Apply this diff to correct the typos:

             model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-            backbone_trained_ckpts_path: Path to trained ckpts for backbone.
-            head_trained_ckpts_path: Path to trained ckpts for head layer.
+            backbone_trained_ckpts_path: Path to trained checkpoints for backbone.
+            head_trained_ckpts_path: Path to trained checkpoints for head layer.
         """

824-825: Correct typos in docstring parameter descriptions

In the docstring of TopDownCenteredInstanceModel, replace "ckpts" with "checkpoints" for clarity.

Apply this diff to correct the typos:

             model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-            backbone_trained_ckpts_path: Path to trained ckpts for backbone.
-            head_trained_ckpts_path: Path to trained ckpts for head layer.
+            backbone_trained_ckpts_path: Path to trained checkpoints for backbone.
+            head_trained_ckpts_path: Path to trained checkpoints for head layer.
         """

905-906: Correct typos in docstring parameter descriptions

In the docstring of CentroidModel, replace "ckpts" with "checkpoints" for clarity.

Apply this diff to correct the typos:

             model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-            backbone_trained_ckpts_path: Path to trained ckpts for backbone.
-            head_trained_ckpts_path: Path to trained ckpts for head layer.
+            backbone_trained_ckpts_path: Path to trained checkpoints for backbone.
+            head_trained_ckpts_path: Path to trained checkpoints for head layer.
         """

986-987: Correct typos in docstring parameter descriptions

In the docstring of BottomUpModel, replace "ckpts" with "checkpoints" for clarity.

Apply this diff to correct the typos:

             model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-            backbone_trained_ckpts_path: Path to trained ckpts for backbone.
-            head_trained_ckpts_path: Path to trained ckpts for head layer.
+            backbone_trained_ckpts_path: Path to trained checkpoints for backbone.
+            head_trained_ckpts_path: Path to trained checkpoints for head layer.
         """
🛑 Comments failed to post (20)
tests/assets/minimal_instance/training_config.yaml (1)

82-82: 💡 Codebase verification

Based on the search results, the implementation details were found in sleap_nn/training/model_trainer.py:

Clarify bin_files_path behavior in configuration documentation

The implementation shows that bin_files_path is used as follows:

  • When null (default), it uses the input directory path
  • When set, creates a timestamped subdirectory under the specified path ({bin_files_path}/chunks_{timestamp})
  • The path can be absolute or relative, and proper directory creation with permissions is handled

Please add these details to the configuration documentation to help users understand:

  • The default behavior (null = use input directory)
  • The automatic timestamped subdirectory creation
  • That both absolute and relative paths are supported
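
A sketch of the documented behavior, for illustration only (the chunks_{timestamp} pattern and the null fallback come from the implementation summary above; the helper itself is hypothetical):

```python
from datetime import datetime
from pathlib import Path


def resolve_chunks_dir(bin_files_path, default_dir):
    """null/None -> use the input directory; otherwise create a timestamped subdirectory."""
    if bin_files_path is None:
        return Path(default_dir)
    chunks_dir = Path(bin_files_path) / f"chunks_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    chunks_dir.mkdir(parents=True, exist_ok=True)  # may raise OSError without write permissions
    return chunks_dir
```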
🔗 Analysis chain

Enhance bin_files_path configuration.

The new bin_files_path parameter is crucial for the .bin file reuse feature, but needs more clarity:

  1. Document the expected path format (absolute/relative)
  2. Clarify what happens when the value is null
  3. Consider adding path validation

Let's check if there's any documentation or validation for this parameter:


Let me try another search strategy to find any documentation or validation related to bin_files_path:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for documentation or validation related to bin_files_path
rg -A 5 "bin_files_path"

Length of output: 195


Script:

#!/bin/bash
# Search for bin_files_path in Python files and YAML files
fd -e py -e yaml --exec rg -l "bin_files_path" {}

# Get content of files containing bin_files_path
fd -e py -e yaml --exec rg -B 5 -A 5 "bin_files_path" {}

Length of output: 12827

initial_config.yaml (3)

1-104: ⚠️ Potential issue

Fix newline characters in the file.

The file uses incorrect newline characters which can cause issues across different operating systems.

Convert all line endings to Unix-style (LF) newlines. You can do this using:

  • Git: git config --global core.autocrlf input
  • Editor: Most modern editors can handle this automatically
  • Command line: dos2unix initial_config.yaml
🧰 Tools
🪛 yamllint

[error] 1-1: wrong new line character: expected \n

(new-lines)


84-84: ⚠️ Potential issue

Remove API key field from configuration.

Storing API keys in configuration files, even if empty, is a security risk as it might encourage users to commit sensitive information.

Consider:

  1. Removing the API key field entirely and using environment variables
  2. Adding a warning comment about not committing API keys
-    api_key: ''
+    # Configure API key using WANDB_API_KEY environment variable
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    # Configure API key using WANDB_API_KEY environment variable

3-4: ⚠️ Potential issue

Replace hardcoded Windows paths with platform-independent paths.

The current paths have several issues:

  1. Windows-specific backslashes make it non-portable
  2. Absolute paths are not suitable for version control
  3. Using the same file for both training and validation could lead to overfitting

Consider using:

  1. Relative paths with forward slashes
  2. Different datasets for training and validation
-  train_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
-  val_labels_path: C:\Users\TalmoLab\Desktop\Divya\sleap-nn\tests\assets/minimal_instance.pkg.slp
+  train_labels_path: data/train/minimal_instance.pkg.slp
+  val_labels_path: data/val/minimal_instance.pkg.slp
📝 Committable suggestion


  train_labels_path: data/train/minimal_instance.pkg.slp
  val_labels_path: data/val/minimal_instance.pkg.slp
docs/config.md (1)

178-188: ⚠️ Potential issue

Fix markdown formatting issues.

There are several formatting inconsistencies in the scheduler configuration section:

  1. List indentation is inconsistent
  2. Unnecessary spaces in emphasis markers

Apply these formatting fixes:

-        - `scheduler`: (str) Name of the scheduler to use. Valid schedulers: `"StepLR"`, `"ReduceLROnPlateau"`.
-        - `step_lr`:
-            - `step_size`: (int) Period of learning rate decay. If `step_size`=10, then every 10 epochs, learning rate will be reduced by a factor of `gamma`.
+    - `scheduler`: (str) Name of the scheduler to use. Valid schedulers: `"StepLR"`, `"ReduceLROnPlateau"`.
+    - `step_lr`:
+      - `step_size`: (int) Period of learning rate decay. If `step_size`=10, then every 10 epochs, learning rate will be reduced by a factor of `gamma`.

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 LanguageTool

[uncategorized] ~179-~179: Loose punctuation mark.
Context: ...ReduceLROnPlateau". - step_lr: - step_size`: (int) Period...

(UNLIKELY_OPENING_PUNCTUATION)


[uncategorized] ~182-~182: Loose punctuation mark.
Context: ...*: 0.1. - reduce_lr_on_plateau: - threshold: (float) Thre...

(UNLIKELY_OPENING_PUNCTUATION)

🪛 Markdownlint

178-178: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


179-179: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


180-180: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


181-181: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


182-182: Expected: 4; Actual: 8
Unordered list indentation

(MD007, ul-indent)


183-183: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


184-184: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


185-185: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


186-186: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


187-187: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


188-188: Expected: 6; Actual: 12
Unordered list indentation

(MD007, ul-indent)


184-184: null
Spaces inside emphasis markers

(MD037, no-space-in-emphasis)


184-184: null
Spaces inside emphasis markers

(MD037, no-space-in-emphasis)

sleap_nn/inference/predictors.py (3)

594-617: 🛠️ Refactor suggestion

Refactor duplicated weight loading logic.

The weight loading logic is duplicated between the centroid and confmap models. Consider extracting this into a helper function to improve maintainability and reduce code duplication.

+def _load_model_weights(model, backbone_ckpt_path: Optional[str], head_ckpt_path: Optional[str]) -> None:
+    """Load backbone and head weights into the model.
+    
+    Args:
+        model: The model to load weights into
+        backbone_ckpt_path: Path to backbone checkpoint
+        head_ckpt_path: Path to head checkpoint
+    """
+    if backbone_ckpt_path is not None and head_ckpt_path is not None:
+        print(f"Loading backbone weights from `{backbone_ckpt_path}` ...")
+        ckpt = torch.load(backbone_ckpt_path)
+        ckpt["state_dict"] = {
+            k: v for k, v in ckpt["state_dict"].items()
+            if ".backbone" in k
+        }
+        model.load_state_dict(ckpt["state_dict"], strict=False)
+    elif backbone_ckpt_path is not None:
+        print(f"Loading weights from `{backbone_ckpt_path}` ...")
+        ckpt = torch.load(backbone_ckpt_path)
+        model.load_state_dict(ckpt["state_dict"], strict=False)
+
+    if head_ckpt_path is not None:
+        print(f"Loading head weights from `{head_ckpt_path}` ...")
+        ckpt = torch.load(head_ckpt_path)
+        ckpt["state_dict"] = {
+            k: v for k, v in ckpt["state_dict"].items()
+            if ".head_layers" in k
+        }
+        model.load_state_dict(ckpt["state_dict"], strict=False)

Then use this helper function:

-        if backbone_ckpt_path is not None and head_ckpt_path is not None:
-            print(f"Loading backbone weights from `{backbone_ckpt_path}` ...")
-            ckpt = torch.load(backbone_ckpt_path)
-            ckpt["state_dict"] = {
-                k: ckpt["state_dict"][k]
-                for k in ckpt["state_dict"].keys()
-                if ".backbone" in k
-            }
-            centroid_model.load_state_dict(ckpt["state_dict"], strict=False)
-
-        elif backbone_ckpt_path is not None:
-            print(f"Loading weights from `{backbone_ckpt_path}` ...")
-            ckpt = torch.load(backbone_ckpt_path)
-            centroid_model.load_state_dict(ckpt["state_dict"], strict=False)
-
-        if head_ckpt_path is not None:
-            print(f"Loading head weights from `{head_ckpt_path}` ...")
-            ckpt = torch.load(head_ckpt_path)
-            ckpt["state_dict"] = {
-                k: ckpt["state_dict"][k]
-                for k in ckpt["state_dict"].keys()
-                if ".head_layers" in k
-            }
-            centroid_model.load_state_dict(ckpt["state_dict"], strict=False)
+        _load_model_weights(centroid_model, backbone_ckpt_path, head_ckpt_path)

Also applies to: 636-659

🧰 Tools
🪛 Ruff

599-599: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


614-614: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


1349-1372: 🛠️ Refactor suggestion

Apply consistent improvements across all predictor classes.

The same improvements suggested for other predictor classes should be applied here:

  1. Simplify dictionary key checks
  2. Use the shared weight loading helper function

This ensures consistency across the codebase and reduces maintenance overhead.

🧰 Tools
🪛 Ruff

1354-1354: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


1369-1369: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


988-1011: 🛠️ Refactor suggestion

Simplify dictionary key checks and apply consistent weight loading pattern.

  1. Simplify the dictionary key checks by removing unnecessary .keys() calls
  2. Use the same helper function suggested for TopDownPredictor to handle weight loading
-                k: ckpt["state_dict"][k]
-                for k in ckpt["state_dict"].keys()
+                k: v for k, v in ckpt["state_dict"].items()
                 if ".backbone" in k

Apply the same _load_model_weights helper function here to maintain consistency and reduce code duplication.

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff

993-993: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


1008-1008: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

tests/training/test_model_trainer.py (2)

338-376: 🛠️ Refactor suggestion

Avoid hardcoded indices when accessing model parameters

Directly accessing model parameters with hardcoded indices like [0, 0, :] can lead to maintenance issues if the model architecture changes. Consider using parameter names or iterating over the parameters for a more robust approach.


379-424: ⚠️ Potential issue

Add assertions to verify reuse of .bin files

The test test_reuse_bin_files sets up for reusing .bin files but lacks assertions to confirm that the files are indeed reused. Adding assertions will strengthen the test by ensuring the files are not regenerated.

Consider adding these assertions:

assert os.path.exists(trainer1.train_input_dir)
assert os.path.exists(trainer1.val_input_dir)
assert os.path.exists(trainer2.train_input_dir)
assert os.path.exists(trainer2.val_input_dir)
assert trainer1.train_input_dir == trainer2.train_input_dir
assert trainer1.val_input_dir == trainer2.val_input_dir
tests/inference/test_predictors.py (5)

188-218: 🛠️ Refactor suggestion

Refactor duplicate test code into a helper function

The code block from lines 188 to 218 appears to be duplicated in other test functions (e.g., lines 447-496 and 670-704). Refactoring this repeated code into a helper function would enhance maintainability and reduce code duplication.


447-496: 🛠️ Refactor suggestion

Refactor duplicate test code into a helper function

The code from lines 447 to 496 is similar to code in other test sections. Refactoring this repeated logic into a shared helper function can improve code maintainability and readability.


670-704: 🛠️ Refactor suggestion

Refactor duplicate test code into a helper function

The code block from lines 670 to 704 is repeated in other tests. Refactoring into a helper function will reduce redundancy and simplify future maintenance.


491-491: 🛠️ Refactor suggestion

Use np.testing.assert_allclose for numerical assertions

For consistency and better error handling, use np.testing.assert_allclose instead of manual assertions when comparing numerical arrays.

Apply this diff:

-assert np.all(np.abs(head_layer_ckpt - model_weights) < 1e-6)
+np.testing.assert_allclose(head_layer_ckpt, model_weights, atol=1e-6)
📝 Committable suggestion


        np.testing.assert_allclose(head_layer_ckpt, model_weights, atol=1e-6)

218-218: 🛠️ Refactor suggestion

Use np.testing.assert_allclose for numerical assertions

Instead of manually checking numerical closeness with assert np.all(np.abs(...)) < tolerance, consider using np.testing.assert_allclose for better readability and error messages.

Apply this diff to improve the assertion:

-assert np.all(np.abs(head_layer_ckpt - model_weights) < 1e-6)
+np.testing.assert_allclose(head_layer_ckpt, model_weights, atol=1e-6)
📝 Committable suggestion


    np.testing.assert_allclose(head_layer_ckpt, model_weights, atol=1e-6)
sleap_nn/training/model_trainer.py (5)

380-380: ⚠️ Potential issue

Clarify default behavior of delete_bin_files_after_training parameter

The parameter delete_bin_files_after_training defaults to True, meaning .bin files will be deleted after training. This could be unexpected for users who want to reuse these files. Consider setting the default to False or clearly documenting this behavior.


240-240: ⚠️ Potential issue

Use raise ... from e when re-raising exceptions to preserve traceback

When re-raising an exception, include the original exception using from e to maintain the traceback.

Apply this diff to modify the raise statement:

-            raise Exception(f"Error while creating the `.bin` files... {e}")
+            raise Exception("Error while creating the `.bin` files...") from e
📝 Committable suggestion


                raise Exception("Error while creating the `.bin` files...") from e
🧰 Tools
🪛 Ruff

240-240: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


224-226: ⚠️ Potential issue

Use raise ... from e when re-raising exceptions to preserve traceback

When raising a new exception within an except block, it's best practice to use raise ... from e to maintain the original traceback.

Apply this diff to modify the raise statement:

-            raise OSError(
-                f"Cannot create a new folder in {self.bin_files_path}. Check the permissions to the given Checkpoint directory. \n {e}"
-            )
+            raise OSError(
+                f"Cannot create a new folder in {self.bin_files_path}. Check the permissions to the given Checkpoint directory."
+            ) from e
📝 Committable suggestion


                        raise OSError(
                            f"Cannot create a new folder in {self.bin_files_path}. Check the permissions to the given Checkpoint directory."
                        ) from e
🧰 Tools
🪛 Ruff

224-226: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


613-613: ⚠️ Potential issue

Simplify dictionary key iteration by removing .keys()

When iterating over a dictionary's keys, you can omit .keys() for simplicity.

Apply this diff to simplify the comprehension:

     ckpt["state_dict"] = {
         k: ckpt["state_dict"][k]
-        for k in ckpt["state_dict"].keys()
+        for k in ckpt["state_dict"]
         if ".backbone" in k
     }

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff

613-613: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


624-624: ⚠️ Potential issue

Simplify dictionary key iteration by removing .keys()

When iterating over a dictionary's keys, you can omit .keys() for simplicity.

Apply this diff to simplify the comprehension:

         ckpt["state_dict"] = {
             k: ckpt["state_dict"][k]
-            for k in ckpt["state_dict"].keys()
+            for k in ckpt["state_dict"]
             if ".head_layers" in k
         }

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff

624-624: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


codecov bot commented Nov 2, 2024

Codecov Report

Attention: Patch coverage is 86.95652% with 6 lines in your changes missing coverage. Please review.

Project coverage is 97.51%. Comparing base (1234f51) to head (7c56e6e).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| sleap_nn/training/model_trainer.py | 88.37% | 5 Missing ⚠️ |
| sleap_nn/inference/utils.py | 66.66% | 1 Missing ⚠️ |
Additional details and impacted files
@@                        Coverage Diff                         @@
##           divya/load-head-ckpt-inference     #116      +/-   ##
==================================================================
+ Coverage                           97.34%   97.51%   +0.17%     
==================================================================
  Files                                  38       38              
  Lines                                3761     3778      +17     
==================================================================
+ Hits                                 3661     3684      +23     
+ Misses                                100       94       -6     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.
