Add option to load trained weights for Head layers #114

Open · wants to merge 1 commit into base: divya/add-config-wandb
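
This PR splits the single `trained_ckpts_path` argument of `ModelTrainer.train()`, `ModelTrainer._initialize_model()`, and the `TrainingModel` classes into two independent options, `backbone_trained_ckpts_path` and `head_trained_ckpts_path`, so the backbone (encoder + decoder) and the head layers can each be initialized from a trained checkpoint, including different ones. A minimal usage sketch; the config file and checkpoint paths below are placeholders, not part of this PR:

from omegaconf import OmegaConf
from sleap_nn.training.model_trainer import ModelTrainer

# Placeholder config; any valid sleap-nn training config would do here.
config = OmegaConf.load("config.yaml")
trainer = ModelTrainer(config)

# Either argument may be omitted; each defaults to None, in which case that
# part of the model keeps its fresh initialization.
trainer.train(
    backbone_trained_ckpts_path="runs/backbone_run/best.ckpt",
    head_trained_ckpts_path="runs/head_run/best.ckpt",
)

Passing the same checkpoint to both arguments, as the new test below does, restores backbone and head weights together.
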
107 changes: 83 additions & 24 deletions sleap_nn/training/model_trainer.py
@@ -329,7 +329,11 @@ def run_subprocess():
     def _set_wandb(self):
         wandb.login(key=self.config.trainer_config.wandb.api_key)

-    def _initialize_model(self, trained_ckpts_path: Optional[str] = None):
+    def _initialize_model(
+        self,
+        backbone_trained_ckpts_path: Optional[str] = None,
+        head_trained_ckpts_path: Optional[str] = None,
+    ):
         models = {
             "single_instance": SingleInstanceModel,
             "centered_instance": TopDownCenteredInstanceModel,
@@ -340,13 +344,18 @@ def _initialize_model(self, trained_ckpts_path: Optional[str] = None):
             self.config,
             self.skeletons,
             self.model_type,
-            trained_ckpts_path=trained_ckpts_path,
+            backbone_trained_ckpts_path=backbone_trained_ckpts_path,
+            head_trained_ckpts_path=head_trained_ckpts_path,
         )

     def _get_param_count(self):
         return sum(p.numel() for p in self.model.parameters())

-    def train(self, trained_ckpts_path: Optional[str] = None):
+    def train(
+        self,
+        backbone_trained_ckpts_path: Optional[str] = None,
+        head_trained_ckpts_path: Optional[str] = None,
+    ):
         """Initiate the training by calling the fit method of Trainer."""
         logger = []

@@ -398,7 +407,10 @@ def train(self, trained_ckpts_path: Optional[str] = None):
         # save the configs as yaml in the checkpoint dir
         self.config.trainer_config.wandb.api_key = ""

-        self._initialize_model(trained_ckpts_path)
+        self._initialize_model(
+            backbone_trained_ckpts_path=backbone_trained_ckpts_path,
+            head_trained_ckpts_path=head_trained_ckpts_path,
+        )
         total_params = self._get_param_count()
         self.config.model_config.total_params = total_params
         # save the configs as yaml in the checkpoint dir
@@ -482,15 +494,17 @@ class TrainingModel(L.LightningModule):
         (iii) trainer_config: trainer configs like accelerator, optimiser params.
         skeletons: List of `sio.Skeleton` objects from the input `.slp` file.
         model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-        trained_ckpts_path: Path to trained ckpts.
+        backbone_trained_ckpts_path: Path to trained ckpts for backbone.
+        head_trained_ckpts_path: Path to trained ckpts for head layer.
     """

     def __init__(
         self,
         config: OmegaConf,
         skeletons: Optional[List[sio.Skeleton]],
         model_type: str,
-        trained_ckpts_path: str = None,
+        backbone_trained_ckpts_path: Optional[str] = None,
+        head_trained_ckpts_path: Optional[str] = None,
     ):
         """Initialise the configs and the model."""
         super().__init__()
@@ -500,7 +514,8 @@ def __init__(
         self.trainer_config = self.config.trainer_config
         self.data_config = self.config.data_config
         self.model_type = model_type
-        self.trained_ckpts_path = trained_ckpts_path
+        self.backbone_trained_ckpts_path = backbone_trained_ckpts_path
+        self.head_trained_ckpts_path = head_trained_ckpts_path
         self.input_expand_channels = self.model_config.backbone_config.in_channels
         if self.model_config.pre_trained_weights:  # only for swint and convnext
             ckpt = eval(self.model_config.pre_trained_weights).DEFAULT.get_state_dict(
@@ -557,18 +572,30 @@ def __init__(
         if self.model_config.pre_trained_weights:
             self.model.backbone.enc.load_state_dict(ckpt, strict=False)

-        # Initializing model (encoder + decoder) with trained ckpts
-        # TODO: Handling different input channels
-        if trained_ckpts_path is not None:
-            print(f"Loading weights from `{trained_ckpts_path}` ...")
-            ckpt = torch.load(trained_ckpts_path)
+        # Initializing backbone (encoder + decoder) with trained ckpts
+        if backbone_trained_ckpts_path is not None:
+            print(f"Loading backbone weights from `{backbone_trained_ckpts_path}` ...")
+            ckpt = torch.load(backbone_trained_ckpts_path)
             ckpt["state_dict"] = {
                 k: ckpt["state_dict"][k]
                 for k in ckpt["state_dict"].keys()
-                if ".head" not in k
+                if ".backbone" in k
             }
             self.load_state_dict(ckpt["state_dict"], strict=False)

+        # Initializing head layers with trained ckpts.
+        if head_trained_ckpts_path is not None:
+            print(f"Loading head weights from `{head_trained_ckpts_path}` ...")
+            ckpt = torch.load(head_trained_ckpts_path)
+            ckpt["state_dict"] = {
+                k: ckpt["state_dict"][k]
+                for k in ckpt["state_dict"].keys()
+                if ".head_layers" in k
+            }
+            print(f"from main code: {ckpt['state_dict'].keys()}")
+            self.load_state_dict(ckpt["state_dict"], strict=False)
+
     def forward(self, img):
         """Forward pass of the model."""
         pass
Expand Down Expand Up @@ -683,7 +710,8 @@ class SingleInstanceModel(TrainingModel):
(iii) trainer_config: trainer configs like accelerator, optimiser params.
skeletons: List of `sio.Skeleton` objects from the input `.slp` file.
model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
trained_ckpts_path: Path to trained ckpts.
backbone_trained_ckpts_path: Path to trained ckpts for backbone.
head_trained_ckpts_path: Path to trained ckpts for head layer.

"""

@@ -692,10 +720,17 @@ def __init__(
         config: OmegaConf,
         skeletons: Optional[List[sio.Skeleton]],
         model_type: str,
-        trained_ckpts_path: str = None,
+        backbone_trained_ckpts_path: Optional[str] = None,
+        head_trained_ckpts_path: Optional[str] = None,
     ):
         """Initialise the configs and the model."""
-        super().__init__(config, skeletons, model_type, trained_ckpts_path)
+        super().__init__(
+            config,
+            skeletons,
+            model_type,
+            backbone_trained_ckpts_path,
+            head_trained_ckpts_path,
+        )

     def forward(self, img):
         """Forward pass of the model."""
@@ -756,7 +791,8 @@ class TopDownCenteredInstanceModel(TrainingModel):
         (iii) trainer_config: trainer configs like accelerator, optimiser params.
         skeletons: List of `sio.Skeleton` objects from the input `.slp` file.
         model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-        trained_ckpts_path: Path to trained ckpts.
+        backbone_trained_ckpts_path: Path to trained ckpts for backbone.
+        head_trained_ckpts_path: Path to trained ckpts for head layer.

     """

@@ -765,10 +801,17 @@ def __init__(
         config: OmegaConf,
         skeletons: Optional[List[sio.Skeleton]],
         model_type: str,
-        trained_ckpts_path: str = None,
+        backbone_trained_ckpts_path: Optional[str] = None,
+        head_trained_ckpts_path: Optional[str] = None,
     ):
         """Initialise the configs and the model."""
-        super().__init__(config, skeletons, model_type, trained_ckpts_path)
+        super().__init__(
+            config,
+            skeletons,
+            model_type,
+            backbone_trained_ckpts_path,
+            head_trained_ckpts_path,
+        )

     def forward(self, img):
         """Forward pass of the model."""
@@ -829,7 +872,8 @@ class CentroidModel(TrainingModel):
         (iii) trainer_config: trainer configs like accelerator, optimiser params.
         skeletons: List of `sio.Skeleton` objects from the input `.slp` file.
         model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-        trained_ckpts_path: Path to trained ckpts.
+        backbone_trained_ckpts_path: Path to trained ckpts for backbone.
+        head_trained_ckpts_path: Path to trained ckpts for head layer.

     """

@@ -838,10 +882,17 @@ def __init__(
         config: OmegaConf,
         skeletons: Optional[List[sio.Skeleton]],
         model_type: str,
-        trained_ckpts_path: str = None,
+        backbone_trained_ckpts_path: Optional[str] = None,
+        head_trained_ckpts_path: Optional[str] = None,
     ):
         """Initialise the configs and the model."""
-        super().__init__(config, skeletons, model_type, trained_ckpts_path)
+        super().__init__(
+            config,
+            skeletons,
+            model_type,
+            backbone_trained_ckpts_path,
+            head_trained_ckpts_path,
+        )

     def forward(self, img):
         """Forward pass of the model."""
@@ -902,7 +953,8 @@ class BottomUpModel(TrainingModel):
         (iii) trainer_config: trainer configs like accelerator, optimiser params.
         skeletons: List of `sio.Skeleton` objects from the input `.slp` file.
         model_type: Type of the model. One of `single_instance`, `centered_instance`, `centroid`, `bottomup`.
-        trained_ckpts_path: Path to trained ckpts.
+        backbone_trained_ckpts_path: Path to trained ckpts for backbone.
+        head_trained_ckpts_path: Path to trained ckpts for head layer.

     """

@@ -911,10 +963,17 @@ def __init__(
         config: OmegaConf,
         skeletons: Optional[List[sio.Skeleton]],
         model_type: str,
-        trained_ckpts_path: str = None,
+        backbone_trained_ckpts_path: Optional[str] = None,
+        head_trained_ckpts_path: Optional[str] = None,
     ):
         """Initialise the configs and the model."""
-        super().__init__(config, skeletons, model_type, trained_ckpts_path)
+        super().__init__(
+            config,
+            skeletons,
+            model_type,
+            backbone_trained_ckpts_path,
+            head_trained_ckpts_path,
+        )

     def forward(self, img):
         """Forward pass of the model."""
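
Both loading branches above follow one pattern: filter the checkpoint's `state_dict` down to the keys of interest by substring, then load with `strict=False` so every parameter absent from the filtered dict keeps its current value. A standalone sketch of that pattern; the helper name and example paths are illustrative, not part of the PR:

import torch

def load_partial_state_dict(module, ckpt_path, keep_substring):
    """Load only checkpoint entries whose keys contain `keep_substring`.

    With strict=False, parameters missing from the filtered dict are left
    untouched, so `.backbone` weights can be restored while the head layers
    keep their fresh initialization, or vice versa.
    """
    ckpt = torch.load(ckpt_path)
    filtered = {k: v for k, v in ckpt["state_dict"].items() if keep_substring in k}
    # load_state_dict returns the (missing_keys, unexpected_keys) pair for inspection.
    return module.load_state_dict(filtered, strict=False)

# Mirrors the two branches in TrainingModel.__init__ (illustrative paths):
# load_partial_state_dict(lightning_module, "runs/a/best.ckpt", ".backbone")
# load_partial_state_dict(lightning_module, "runs/b/best.ckpt", ".head_layers")
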
2 changes: 1 addition & 1 deletion tests/fixtures/datasets.py
@@ -89,7 +89,7 @@ def config(sleap_data_dir):
             "stacks": 1,
             "stem_stride": None,
             "middle_block": True,
-            "up_interpolate": True,
+            "up_interpolate": False,
         },
         "head_configs": {
             "single_instance": None,
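
The fixture now sets `up_interpolate` to False. Assuming the flag follows the usual SLEAP convention, False selects transposed convolutions instead of fixed bilinear interpolation for decoder upsampling, which gives the decoder learnable weights that a checkpoint-loading test can actually compare. A sketch of that assumed distinction:

import torch.nn as nn

# Assumed semantics of `up_interpolate`, for illustration only.
up_interpolate = False

if up_interpolate:
    # Fixed bilinear upsampling: no parameters, nothing in the state_dict.
    upsample = nn.Upsample(scale_factor=2, mode="bilinear")
else:
    # Transposed convolution: learnable weights appear in the state_dict.
    upsample = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
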
36 changes: 29 additions & 7 deletions tests/training/test_model_trainer.py
@@ -342,24 +342,46 @@ def test_trainer(config, tmp_path: str, minimal_instance_bottomup_ckpt: str):

     OmegaConf.update(config, "trainer_config.lr_scheduler.scheduler", "StepLR")

-    # check loading trained weights
+
+def test_trainer_load_trained_ckpts(config, tmp_path, minimal_instance_ckpt):
+    """Test loading trained weights for backbone and head layers."""
+
+    OmegaConf.update(
+        config, "trainer_config.save_ckpt_path", f"{tmp_path}/test_model_trainer/"
+    )
+    OmegaConf.update(config, "trainer_config.save_ckpt", True)
+    OmegaConf.update(config, "trainer_config.use_wandb", True)
+    OmegaConf.update(config, "data_config.preprocessing.crop_hw", None)
+    OmegaConf.update(config, "data_config.preprocessing.min_crop_size", 100)
+
+    # check loading trained weights for backbone
     load_weights_config = config.copy()
-    ckpt = torch.load((Path(minimal_instance_bottomup_ckpt) / "best.ckpt").as_posix())
+    ckpt = torch.load((Path(minimal_instance_ckpt) / "best.ckpt").as_posix())
     first_layer_ckpt = ckpt["state_dict"][
         "model.backbone.enc.encoder_stack.0.blocks.0.weight"
     ][0, 0, :].numpy()

+    # load head ckpts
+    head_layer_ckpt = ckpt["state_dict"]["model.head_layers.0.0.weight"][
+        0, 0, :
+    ].numpy()
+
     trainer = ModelTrainer(load_weights_config)
     trainer._create_data_loaders()
     trainer._initialize_model(
-        (Path(minimal_instance_bottomup_ckpt) / "best.ckpt").as_posix()
+        backbone_trained_ckpts_path=(
+            Path(minimal_instance_ckpt) / "best.ckpt"
+        ).as_posix(),
+        head_trained_ckpts_path=(Path(minimal_instance_ckpt) / "best.ckpt").as_posix(),
     )
     model_ckpt = next(trainer.model.parameters())[0, 0, :].detach().numpy()

-    assert np.all(np.abs(first_layer_ckpt - model_ckpt) < 1e-3)
+    assert np.all(np.abs(first_layer_ckpt - model_ckpt) < 1e-6)

-    shutil.rmtree((Path(trainer.bin_files_path) / "train_chunks").as_posix())
-    shutil.rmtree((Path(trainer.bin_files_path) / "val_chunks").as_posix())
+    model_ckpt = (
+        next(trainer.model.model.head_layers.parameters())[0, 0, :].detach().numpy()
+    )
+
+    assert np.all(np.abs(head_layer_ckpt - model_ckpt) < 1e-6)


 def test_topdown_centered_instance_model(config, tmp_path: str):