
Commit

code refactor; added training strategies

JSabadin committed Nov 15, 2024
1 parent 00a2bd3 commit 147c1ee
Showing 17 changed files with 297 additions and 218 deletions.
1 change: 1 addition & 0 deletions luxonis_train/__init__.py
@@ -10,6 +10,7 @@
from .nodes import *
from .optimizers import *
from .schedulers import *
from .strategies import *
from .utils import *
except ImportError as e:
warnings.warn(
3 changes: 3 additions & 0 deletions luxonis_train/callbacks/__init__.py
@@ -25,6 +25,7 @@
from .metadata_logger import MetadataLogger
from .module_freezer import ModuleFreezer
from .test_on_train_end import TestOnTrainEnd
from .training_manager import TrainingManager
from .upload_checkpoint import UploadCheckpoint

CALLBACKS.register_module(module=EarlyStopping)
@@ -38,6 +39,7 @@
CALLBACKS.register_module(module=ModelPruning)
CALLBACKS.register_module(module=GradCamCallback)
CALLBACKS.register_module(module=EMACallback)
CALLBACKS.register_module(module=TrainingManager)


__all__ = [
@@ -53,4 +55,5 @@
"GPUStatsMonitor",
"GradCamCallback",
"EMACallback",
"TrainingManager",
]
28 changes: 28 additions & 0 deletions luxonis_train/callbacks/training_manager.py
@@ -0,0 +1,28 @@
import pytorch_lightning as pl

from luxonis_train.strategies.base_strategy import BaseTrainingStrategy


class TrainingManager(pl.Callback):
def __init__(self, strategy: BaseTrainingStrategy | None = None):
"""Training manager callback that updates the parameters of the
training strategy.
@type strategy: BaseTrainingStrategy | None
@param strategy: The training strategy to be used; if None, the callback does nothing.
"""
self.strategy = strategy

def on_after_backward(
self, trainer: pl.Trainer, pl_module: pl.LightningModule
):
"""PyTorch Lightning hook that is called after the backward
pass.
@type trainer: pl.Trainer
@param trainer: The trainer object.
@type pl_module: pl.LightningModule
@param pl_module: The LightningModule whose strategy parameters are updated.
"""
if self.strategy is not None:
self.strategy.update_parameters(pl_module)
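
Note: base_strategy.py itself is not part of this commit. The following is a minimal sketch of the strategy interface that TrainingManager and LuxonisLightningModule appear to rely on (a constructor taking pl_module and params, plus configure_optimizers() and update_parameters()), inferred only from the calls visible in this diff; all other names and defaults are illustrative.

import pytorch_lightning as pl
import torch


class SketchStrategy:
    """Stand-in for a BaseTrainingStrategy subclass; purely illustrative."""

    def __init__(self, pl_module: pl.LightningModule, params: dict):
        self.pl_module = pl_module
        self.params = params

    def configure_optimizers(
        self,
    ) -> tuple[
        list[torch.optim.Optimizer], list[torch.optim.lr_scheduler.LRScheduler]
    ]:
        # Called from LuxonisLightningModule.configure_optimizers() when a
        # strategy is configured; the strategy builds optimizer and scheduler itself.
        optimizer = torch.optim.SGD(
            self.pl_module.parameters(), lr=self.params.get("lr", 0.02)
        )
        scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
        return [optimizer], [scheduler]

    def update_parameters(self, pl_module: pl.LightningModule) -> None:
        # Called by TrainingManager.on_after_backward(); per-step adjustments
        # to learning rate or momentum would go here.
        pass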
7 changes: 6 additions & 1 deletion luxonis_train/config/config.py
@@ -336,6 +336,11 @@ class SchedulerConfig(BaseModelExtraForbid):
params: Params = {}


class TrainingStrategyConfig(BaseModelExtraForbid):
name: str = "TripleLRSGDStrategy"
params: Params = {}


class TrainerConfig(BaseModelExtraForbid):
preprocessing: PreprocessingConfig = PreprocessingConfig()
use_rich_progress_bar: bool = True
@@ -355,7 +360,6 @@ class TrainerConfig(BaseModelExtraForbid):
profiler: Literal["simple", "advanced"] | None = None
matmul_precision: Literal["medium", "high", "highest"] | None = None
verbose: bool = True
apply_custom_lr: bool = False

seed: int | None = None
n_validation_batches: PositiveInt | None = None
@@ -383,6 +387,7 @@

optimizer: OptimizerConfig = OptimizerConfig()
scheduler: SchedulerConfig = SchedulerConfig()
training_strategy: TrainingStrategyConfig = TrainingStrategyConfig()

@model_validator(mode="after")
def validate_deterministic(self) -> Self:
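
A hedged usage sketch of the new training_strategy field: the strategy is only instantiated when params is non-empty (see luxonis_lightning.py below), and the parameter keys shown here are placeholders, since TripleLRSGDStrategy's accepted options are not part of this diff.

from luxonis_train.config.config import TrainingStrategyConfig

# Placeholder keys for illustration only; the real TripleLRSGDStrategy params
# are defined elsewhere in luxonis_train.strategies.
strategy_cfg = TrainingStrategyConfig(
    name="TripleLRSGDStrategy",
    params={"warmup_epochs": 3, "lr": 0.02},
)
# With the default empty params, LuxonisLightningModule leaves
# self.training_strategy as None and falls back to the optimizer/scheduler config.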
69 changes: 26 additions & 43 deletions luxonis_train/models/luxonis_lightning.py
@@ -1,4 +1,3 @@
import math
from collections import defaultdict
from collections.abc import Mapping
from logging import getLogger
@@ -26,7 +25,11 @@
combine_visualizations,
get_denormalized_images,
)
from luxonis_train.callbacks import BaseLuxonisProgressBar, ModuleFreezer
from luxonis_train.callbacks import (
BaseLuxonisProgressBar,
ModuleFreezer,
TrainingManager,
)
from luxonis_train.config import AttachedModuleConfig, Config
from luxonis_train.nodes import BaseNode
from luxonis_train.utils import (
@@ -43,6 +46,7 @@
CALLBACKS,
OPTIMIZERS,
SCHEDULERS,
STRATEGIES,
Registry,
)

@@ -269,6 +273,16 @@ def __init__(

self.load_checkpoint(self.cfg.model.weights)

if self.cfg.trainer.training_strategy.params:
self.training_strategy = STRATEGIES.get(
self.cfg.trainer.training_strategy.name
)(
pl_module=self,
params=self.cfg.trainer.training_strategy.params,
)
else:
self.training_strategy = None

@property
def core(self) -> "luxonis_train.core.LuxonisModel":
"""Returns the core model."""
@@ -850,6 +864,9 @@ def configure_callbacks(self) -> list[pl.Callback]:
CALLBACKS.get(callback.name)(**callback.params)
)

if self.training_strategy is not None:
callbacks.append(TrainingManager(strategy=self.training_strategy))

Check failure on line 868 in luxonis_train/models/luxonis_lightning.py (GitHub Actions / type-check): Argument of type "TrainingManager" cannot be assigned to parameter "object" of type "Callback" in function "append"; "TrainingManager" is not assignable to "Callback" (reportArgumentType)

return callbacks

def configure_optimizers(
@@ -858,45 +875,17 @@ def configure_optimizers(
list[torch.optim.Optimizer],
list[torch.optim.lr_scheduler.LRScheduler],
]:
"""Configures model optimizers and schedulers with optional
custom learning rates and warm-up logic."""
"""Configures model optimizers and schedulers."""
if self.training_strategy is not None:
return self.training_strategy.configure_optimizers()

cfg_optimizer = self.cfg.trainer.optimizer
cfg_scheduler = self.cfg.trainer.scheduler

if self.cfg.trainer.apply_custom_lr:
assert (
cfg_optimizer.name == "TripleLRSGD"
), "Custom learning rate is only supported for TripleLRSGD optimizer."
assert (
cfg_scheduler.name == "TripleLRScheduler"
), "Custom learning rate is only supported for TripleLRScheduler scheduler."

max_stepnum = math.ceil(
len(self._core.loaders["train"]) / self.cfg.trainer.batch_size
)
custom_optimizer = OPTIMIZERS.get(cfg_optimizer.name)(
self, cfg_optimizer.params
)
optimizer = custom_optimizer.create_optimizer()

custom_scheduler = SCHEDULERS.get(cfg_scheduler.name)(
optimizer,
cfg_scheduler.params,
self.cfg.trainer.epochs,
max_stepnum,
)
scheduler = custom_scheduler.create_scheduler()

self.custom_scheduler = custom_scheduler

return [optimizer], [scheduler]

else:
optim_params = cfg_optimizer.params | {
"params": filter(lambda p: p.requires_grad, self.parameters()),
}
optimizer = OPTIMIZERS.get(cfg_optimizer.name)(**optim_params)
optim_params = cfg_optimizer.params | {
"params": filter(lambda p: p.requires_grad, self.parameters()),
}
optimizer = OPTIMIZERS.get(cfg_optimizer.name)(**optim_params)

def get_scheduler(scheduler_cfg, optimizer):
scheduler_class = SCHEDULERS.get(
@@ -927,12 +916,6 @@ def get_scheduler(scheduler_cfg, optimizer):

return [optimizer], [scheduler]

def on_after_backward(self):
"""Custom logic to adjust learning rates and momentum after
loss.backward."""
if self.cfg.trainer.apply_custom_lr:
self.custom_scheduler.update_learning_rate(self.current_epoch)

def load_checkpoint(self, path: str | Path | None) -> None:
"""Loads checkpoint weights from provided path.
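
Because strategies are resolved by name through the STRATEGIES registry, a custom strategy could be registered along the following lines. This mirrors the CALLBACKS.register_module(module=...) calls above, and the import path mirrors the one used for OPTIMIZERS in optimizers.py; the class name is hypothetical and the base-class constructor signature is assumed from the STRATEGIES.get(name)(pl_module=..., params=...) call in this commit.

import pytorch_lightning as pl

from luxonis_train.strategies.base_strategy import BaseTrainingStrategy
from luxonis_train.utils.registry import STRATEGIES


class MyCustomStrategy(BaseTrainingStrategy):
    """Hypothetical strategy used only to illustrate registration."""

    def update_parameters(self, pl_module: pl.LightningModule) -> None:
        # Invoked by TrainingManager after every backward pass.
        pass


# Once registered, setting trainer.training_strategy.name to "MyCustomStrategy"
# (with non-empty params) selects this class.
STRATEGIES.register_module(module=MyCustomStrategy)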
11 changes: 6 additions & 5 deletions luxonis_train/nodes/backbones/efficientrep/efficientrep.py
@@ -132,13 +132,14 @@ def __init__(

def initialize_weights(self):
for m in self.modules():
t = type(m)
if t is nn.Conv2d:
if isinstance(m, nn.Conv2d):
pass
elif t is nn.BatchNorm2d:
m.eps = 1e-3
elif isinstance(m, nn.BatchNorm2d):
m.eps = 0.001
m.momentum = 0.03
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
elif isinstance(
m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
):
m.inplace = True

def set_export_mode(self, mode: bool = True) -> None:
11 changes: 6 additions & 5 deletions luxonis_train/nodes/blocks/blocks.py
@@ -60,13 +60,14 @@ def __init__(self, n_classes: int, in_channels: int):

def initialize_weights(self):
for m in self.modules():
t = type(m)
if t is nn.Conv2d:
if isinstance(m, nn.Conv2d):
pass
elif t is nn.BatchNorm2d:
m.eps = 1e-3
elif isinstance(m, nn.BatchNorm2d):
m.eps = 0.001
m.momentum = 0.03
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
elif isinstance(
m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
):
m.inplace = True

def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
11 changes: 6 additions & 5 deletions luxonis_train/nodes/heads/efficient_bbox_head.py
@@ -105,13 +105,14 @@ def __init__(

def initialize_weights(self):
for m in self.modules():
t = type(m)
if t is nn.Conv2d:
if isinstance(m, nn.Conv2d):
pass
elif t is nn.BatchNorm2d:
m.eps = 1e-3
elif isinstance(m, nn.BatchNorm2d):
m.eps = 0.001
m.momentum = 0.03
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
elif isinstance(
m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
):
m.inplace = True

def forward(
11 changes: 6 additions & 5 deletions luxonis_train/nodes/necks/reppan_neck/reppan_neck.py
@@ -172,13 +172,14 @@ def __init__(

def initialize_weights(self):
for m in self.modules():
t = type(m)
if t is nn.Conv2d:
if isinstance(m, nn.Conv2d):
pass
elif t is nn.BatchNorm2d:
m.eps = 1e-3
elif isinstance(m, nn.BatchNorm2d):
m.eps = 0.001
m.momentum = 0.03
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
elif isinstance(
m, (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU)
):
m.inplace = True

def forward(self, inputs: list[Tensor]) -> list[Tensor]:
58 changes: 0 additions & 58 deletions luxonis_train/optimizers/custom_optimizers.py

This file was deleted.

3 changes: 0 additions & 3 deletions luxonis_train/optimizers/optimizers.py
@@ -2,8 +2,6 @@

from luxonis_train.utils.registry import OPTIMIZERS

from .custom_optimizers import TripleLRSGD

for optimizer in [
optim.Adadelta,
optim.Adagrad,
@@ -17,6 +15,5 @@
optim.RAdam,
optim.RMSprop,
optim.SGD,
TripleLRSGD,
]:
OPTIMIZERS.register_module(module=optimizer)
