Commit

add docs and minor fix
JSabadin committed Nov 16, 2024
1 parent dc1d468 commit 9dfd8da
Showing 3 changed files with 13 additions and 3 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -567,6 +567,7 @@ model.tune()
- [**Callbacks**](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/callbacks/README.md): Allow custom code to be executed at different stages of training.
- [**Optimizers**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#optimizer): Control how the model's weights are updated.
- [**Schedulers**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#scheduler): Adjust the learning rate during training.
+- [**Training Strategy**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#training-strategy): Specify a custom combination of optimizer and scheduler to tailor the training process for specific use cases.

**Creating Custom Components:**

@@ -581,6 +582,7 @@ Registered components can be referenced in the config file. Custom components ne
- **Callbacks** - [`lightning.pytorch.callbacks.Callback`](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html), requires manual registration to the `CALLBACKS` registry
- **Optimizers** - [`torch.optim.Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), requires manual registration to the `OPTIMIZERS` registry
- **Schedulers** - [`torch.optim.lr_scheduler.LRScheduler`](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate), requires manual registration to the `SCHEDULERS` registry
+- **Training Strategy** - [`BaseTrainingStrategy`](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/strategies/base_strategy.py)

**Examples:**

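The registration flow described in the README section above, extended by this commit to training strategies, can be sketched with a minimal stand-in registry. This is an illustration only: the `Registry` class here is a simplified substitute for the real luxonis-train machinery, and only the `STRATEGIES.get(name)` lookup pattern is taken from this commit.

```python
# Minimal sketch of the component-registry pattern; the Registry and
# BaseTrainingStrategy below are simplified stand-ins, not the real classes.

class Registry:
    """Maps component names to classes so configs can reference them by name."""

    def __init__(self):
        self._items = {}

    def register_module(self, cls):
        # Used as a decorator: stores the class under its own name.
        self._items[cls.__name__] = cls
        return cls

    def get(self, name):
        return self._items[name]


STRATEGIES = Registry()


class BaseTrainingStrategy:
    """Placeholder for the base class linked in the README."""

    def configure_optimizers(self):
        raise NotImplementedError


@STRATEGIES.register_module
class CustomStrategy(BaseTrainingStrategy):
    def __init__(self, lr=0.01):
        self.lr = lr

    def configure_optimizers(self):
        # A real strategy would return optimizer/scheduler objects here.
        return {"optimizer": "SGD", "lr": self.lr}


# Mirrors the lookup this commit adds in luxonis_lightning.py:
# STRATEGIES.get(self.cfg.trainer.training_strategy.name)(...)
strategy = STRATEGIES.get("CustomStrategy")(lr=0.1)
print(strategy.configure_optimizers())
```

Registering the class under its own name is what lets the config file's `name` field stay a plain string.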
4 changes: 2 additions & 2 deletions luxonis_train/config/config.py
@@ -337,7 +337,7 @@ class SchedulerConfig(BaseModelExtraForbid):


class TrainingStrategyConfig(BaseModelExtraForbid):
-    name: str = "TripleLRSGDStrategy"
+    name: str
    params: Params = {}


@@ -387,7 +387,7 @@ class TrainerConfig(BaseModelExtraForbid):

    optimizer: OptimizerConfig = OptimizerConfig()
    scheduler: SchedulerConfig = SchedulerConfig()
-    training_strategy: TrainingStrategyConfig = TrainingStrategyConfig()
+    training_strategy: TrainingStrategyConfig | None = None

    @model_validator(mode="after")
    def validate_deterministic(self) -> Self:
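With this change, `training_strategy` is opt-in: it has no default name and defaults to `None` on the trainer config. A config enabling it might look like the following sketch (the `trainer` section layout follows the repo's YAML configs; `TripleLRSGDStrategy` is the former default value removed above, and the `params` contents are hypothetical):

```yaml
trainer:
  # When training_strategy is set, standalone optimizer/scheduler entries
  # are ignored and a warning is logged (see luxonis_lightning.py below).
  training_strategy:
    name: TripleLRSGDStrategy
    params: {}  # strategy-specific options; contents are hypothetical
```

Omitting the `training_strategy` key entirely keeps the plain optimizer/scheduler path.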
10 changes: 9 additions & 1 deletion luxonis_train/models/luxonis_lightning.py
@@ -273,7 +273,15 @@ def __init__(

        self.load_checkpoint(self.cfg.model.weights)

-        if self.cfg.trainer.training_strategy.params:
+        if self.cfg.trainer.training_strategy is not None:
+            if self.cfg.trainer.optimizer is not None:
+                logger.warning(
+                    "Training strategy is active; the specified optimizer will be ignored."
+                )
+            if self.cfg.trainer.scheduler is not None:
+                logger.warning(
+                    "Training strategy is active; the specified scheduler will be ignored."
+                )
+            self.training_strategy = STRATEGIES.get(
+                self.cfg.trainer.training_strategy.name
+            )(
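The precedence logic this commit adds can be condensed into a standalone sketch. The function and the dict-based config below are simplified stand-ins (the real code runs inside `LuxonisLightningModule.__init__` and instantiates the strategy via `STRATEGIES.get`); only the guard structure and warning messages come from the diff above.

```python
import logging

logger = logging.getLogger("luxonis_train.sketch")


def resolve_strategy(training_strategy, optimizer, scheduler):
    """Sketch of the guard added in this commit: a configured training
    strategy takes precedence, and any standalone optimizer/scheduler
    is ignored with a warning rather than silently dropped."""
    if training_strategy is None:
        # No strategy configured: normal optimizer/scheduler path.
        return None
    if optimizer is not None:
        logger.warning(
            "Training strategy is active; the specified optimizer will be ignored."
        )
    if scheduler is not None:
        logger.warning(
            "Training strategy is active; the specified scheduler will be ignored."
        )
    # Real code: STRATEGIES.get(training_strategy.name)(...)
    return training_strategy["name"]


print(resolve_strategy({"name": "TripleLRSGDStrategy"}, "SGD", None))
# prints TripleLRSGDStrategy (plus a warning about the ignored optimizer)
```

Warning instead of raising keeps old configs (which always carried optimizer/scheduler defaults) working while making the override explicit in the logs.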
