Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for model initialization during eval #9

Merged
merged 9 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions gluefactory/eval/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def load_model(model_conf, checkpoint):
model = load_experiment(checkpoint, conf=model_conf).eval()
else:
model = get_model("two_view_pipeline")(model_conf).eval()
if not model.is_initialized():
raise ValueError(
"The provided model has non-initialized parameters. "
+ "Try to load a checkpoint instead."
)
return model


Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/backbones/dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class DinoV2(BaseModel):

def _init(self, conf):
self.net = torch.hub.load("facebookresearch/dinov2", conf.weights)
self.set_initialized()

def _forward(self, data):
img = data["image"]
Expand Down
30 changes: 30 additions & 0 deletions gluefactory/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class BaseModel(nn.Module, metaclass=MetaModel):
required_data_keys = []
strict_conf = False

are_weights_initialized = False

def __init__(self, conf):
"""Perform some logic and call the _init method of the child model."""
super().__init__()
Expand Down Expand Up @@ -125,3 +127,31 @@ def _forward(self, data):
def loss(self, pred, data):
"""To be implemented by the child class."""
raise NotImplementedError

def load_state_dict(self, *args, **kwargs):
"""Load the state dict of the model, and set the model to initialized."""
ret = super().load_state_dict(*args, **kwargs)
self.set_initialized()
return ret

def is_initialized(self):
"""Recursively check if the model is initialized, i.e. weights are loaded"""
is_initialized = True # initialize to true and perform recursive and
for _, w in self.named_children():
if isinstance(w, BaseModel):
# if children is BaseModel, we perform recursive check
is_initialized = is_initialized and w.is_initialized()
else:
# else, we check if self is initialized or the children has no params
n_params = len(list(w.parameters()))
is_initialized = is_initialized and (
n_params == 0 or self.are_weights_initialized
)
return is_initialized

def set_initialized(self, to: bool = True):
"""Recursively set the initialization state."""
self.are_weights_initialized = to
for _, w in self.named_parameters():
if isinstance(w, BaseModel):
w.set_initialized(to)
1 change: 1 addition & 0 deletions gluefactory/models/extractors/disk_kornia.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class DISK(BaseModel):

def _init(self, conf):
self.model = kornia.feature.DISK.from_pretrained(conf.weights)
self.set_initialized()

def _get_dense_outputs(self, images):
B = images.shape[0]
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/extractors/keynet_affnet_hardnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def _init(self, conf):
upright=conf.upright,
scale_laf=conf.scale_laf,
)
self.set_initialized()

def _forward(self, data):
image = data["image"]
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/extractors/sift_kornia.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def _init(self, conf):
self.sift = kornia.feature.SIFTFeature(
num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift
)
self.set_initialized()

def _forward(self, data):
lafs, scores, descriptors = self.sift(data["image"])
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/lines/deeplsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _init(self, conf):
ckpt = torch.load(ckpt, map_location="cpu")
self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
self.net.load_state_dict(ckpt["model"])
self.set_initialized()

def download_model(self, path):
import subprocess
Expand Down
1 change: 1 addition & 0 deletions gluefactory/models/matchers/kornia_loftr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class LoFTRModule(BaseModel):

def _init(self, conf):
self.net = kornia.feature.LoFTR(pretrained="outdoor")
self.set_initialized()

def _forward(self, data):
image0 = data["view0"]["image"]
Expand Down
2 changes: 1 addition & 1 deletion gluefactory/models/matchers/lightglue_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class LightGlue(BaseModel):
def _init(self, conf):
dconf = OmegaConf.to_container(conf)
self.net = LightGlue_(dconf.pop("features"), **dconf).cuda()
# self.net.compile()
self.set_initialized()

def _forward(self, data):
view0 = {
Expand Down