Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for model initialization during eval #9

Merged
merged 9 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions gluefactory/eval/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def load_model(model_conf, checkpoint):
model = load_experiment(checkpoint, conf=model_conf).eval()
else:
model = get_model("two_view_pipeline")(model_conf).eval()
if not model.is_initialized():
assert model.is_initialized(), (
Phil26AT marked this conversation as resolved.
Show resolved Hide resolved
"The provided model has non-initialized parameters. "
+ "Try to load a checkpoint instead."
)
return model


Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/backbones/dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class DinoV2(BaseModel):

def _init(self, conf):
self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
self.set_initialized(True)
Phil26AT marked this conversation as resolved.
Show resolved Hide resolved

def _forward(self, data):
img = data["image"]
Expand Down
31 changes: 31 additions & 0 deletions gluefactory/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import omegaconf
from omegaconf import OmegaConf
from torch import nn
from typing import Mapping, Any


class MetaModel(ABCMeta):
Expand Down Expand Up @@ -60,6 +61,8 @@ class BaseModel(nn.Module, metaclass=MetaModel):
required_data_keys = []
strict_conf = False

weights_initialized = False
Phil26AT marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, conf):
"""Perform some logic and call the _init method of the child model."""
super().__init__()
Expand Down Expand Up @@ -125,3 +128,31 @@ def _forward(self, data):
def loss(self, pred, data):
"""To be implemented by the child class."""
raise NotImplementedError

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
"""Load the state dict of the model, and set the model to initialized."""
incompatible_keys = super().load_state_dict(state_dict, strict=strict)
self.set_initialized(True)
return incompatible_keys
Phil26AT marked this conversation as resolved.
Show resolved Hide resolved

def is_initialized(self):
"""Recursively check if the model is initialized, i.e. weights are loaded"""
is_initialized = True # initialize to true and perform recursive and
for _, w in self.named_children():
if isinstance(w, BaseModel):
# if children is BaseModel, we perform recursive check
is_initialized = is_initialized and w.is_initialized()
else:
# else, we check if self is initialized or the children has no params
n_params = len(list(w.parameters()))
is_initialized = is_initialized and (
n_params == 0 or self.weights_initialized
)
return is_initialized

def set_initialized(self, to: bool = True):
"""Recursively set the initialization state."""
self.weights_initialized = to
for _, w in self.named_parameters():
if isinstance(w, BaseModel):
w.set_initialized(to)
1 change: 1 addition & 0 deletions gluefactory/models/extractors/disk_kornia.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class DISK(BaseModel):

def _init(self, conf):
self.model = kornia.feature.DISK.from_pretrained(conf.weights)
self.set_initialized(True)

def _get_dense_outputs(self, images):
B = images.shape[0]
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/extractors/keynet_affnet_hardnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def _init(self, conf):
upright=conf.upright,
scale_laf=conf.scale_laf,
)
self.set_initialized(True)

def _forward(self, data):
image = data["image"]
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/extractors/sift_kornia.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def _init(self, conf):
self.sift = kornia.feature.SIFTFeature(
num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
)
self.set_initialized(True)

def _forward(self, data):
lafs, scores, descriptors = self.sift(data["image"])
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/lines/deeplsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _init(self, conf):
ckpt = torch.load(ckpt, map_location="cpu")
self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
self.net.load_state_dict(ckpt["model"])
self.set_initialized(True)

def download_model(self, path):
import subprocess
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/matchers/kornia_loftr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class LoFTRModule(BaseModel):

def _init(self, conf):
self.net = kornia.feature.LoFTR(pretrained="outdoor")
self.set_initialized(True)

def _forward(self, data):
image0 = data["view0"]["image"]
Expand Down
2 changes: 1 addition & 1 deletion gluefactory/models/matchers/lightglue_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class LightGlue(BaseModel):
def _init(self, conf):
dconf = OmegaConf.to_container(conf)
self.net = LightGlue_(dconf.pop("features"), **dconf).cuda()
# self.net.compile()
self.weights_initialized = True
Phil26AT marked this conversation as resolved.
Show resolved Hide resolved

def _forward(self, data):
view0 = {
Expand Down
Loading