Skip to content

Commit

Permalink
Check for model initialization in eval (#9)
Browse files Browse the repository at this point in the history
* Check for model initialization during eval

* Change from warning to assert

* change assertion to ValueError

* fix isort and rename are_weights_initialized

* cleanup set_initialized()

* fix variable name bug

* Make load_state_dict forward compatible

Co-authored-by: Paul-Edouard Sarlin <[email protected]>

* Remove unused imports

---------

Co-authored-by: Paul-Edouard Sarlin <[email protected]>
  • Loading branch information
Phil26AT and sarlinpe authored Oct 10, 2023
1 parent f7b587e commit 22154a6
Show file tree
Hide file tree
Showing 9 changed files with 42 additions and 1 deletion.
5 changes: 5 additions & 0 deletions gluefactory/eval/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def load_model(model_conf, checkpoint):
model = load_experiment(checkpoint, conf=model_conf).eval()
else:
model = get_model("two_view_pipeline")(model_conf).eval()
if not model.is_initialized():
raise ValueError(
"The provided model has non-initialized parameters. "
+ "Try to load a checkpoint instead."
)
return model


Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/backbones/dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class DinoV2(BaseModel):

def _init(self, conf):
self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
self.set_initialized()

def _forward(self, data):
img = data["image"]
Expand Down
30 changes: 30 additions & 0 deletions gluefactory/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class BaseModel(nn.Module, metaclass=MetaModel):
required_data_keys = []
strict_conf = False

are_weights_initialized = False

def __init__(self, conf):
"""Perform some logic and call the _init method of the child model."""
super().__init__()
Expand Down Expand Up @@ -125,3 +127,31 @@ def _forward(self, data):
def loss(self, pred, data):
"""To be implemented by the child class."""
raise NotImplementedError

def load_state_dict(self, *args, **kwargs):
"""Load the state dict of the model, and set the model to initialized."""
ret = super().load_state_dict(*args, **kwargs)
self.set_initialized()
return ret

def is_initialized(self):
"""Recursively check if the model is initialized, i.e. weights are loaded"""
is_initialized = True # initialize to true and perform recursive and
for _, w in self.named_children():
if isinstance(w, BaseModel):
# if children is BaseModel, we perform recursive check
is_initialized = is_initialized and w.is_initialized()
else:
# else, we check if self is initialized or the children has no params
n_params = len(list(w.parameters()))
is_initialized = is_initialized and (
n_params == 0 or self.are_weights_initialized
)
return is_initialized

def set_initialized(self, to: bool = True):
"""Recursively set the initialization state."""
self.are_weights_initialized = to
for _, w in self.named_parameters():
if isinstance(w, BaseModel):
w.set_initialized(to)
1 change: 1 addition & 0 deletions gluefactory/models/extractors/disk_kornia.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class DISK(BaseModel):

def _init(self, conf):
self.model = kornia.feature.DISK.from_pretrained(conf.weights)
self.set_initialized()

def _get_dense_outputs(self, images):
B = images.shape[0]
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/extractors/keynet_affnet_hardnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def _init(self, conf):
upright=conf.upright,
scale_laf=conf.scale_laf,
)
self.set_initialized()

def _forward(self, data):
image = data["image"]
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/extractors/sift_kornia.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def _init(self, conf):
self.sift = kornia.feature.SIFTFeature(
num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
)
self.set_initialized()

def _forward(self, data):
lafs, scores, descriptors = self.sift(data["image"])
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/lines/deeplsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _init(self, conf):
ckpt = torch.load(ckpt, map_location="cpu")
self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
self.net.load_state_dict(ckpt["model"])
self.set_initialized()

def download_model(self, path):
import subprocess
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/matchers/kornia_loftr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class LoFTRModule(BaseModel):

def _init(self, conf):
self.net = kornia.feature.LoFTR(pretrained="outdoor")
self.set_initialized()

def _forward(self, data):
image0 = data["view0"]["image"]
Expand Down
2 changes: 1 addition & 1 deletion gluefactory/models/matchers/lightglue_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class LightGlue(BaseModel):
def _init(self, conf):
dconf = OmegaConf.to_container(conf)
self.net = LightGlue_(dconf.pop("features"), **dconf).cuda()
# self.net.compile()
self.set_initialized()

def _forward(self, data):
view0 = {
Expand Down

0 comments on commit 22154a6

Please sign in to comment.