From ed117cd78091e48340c3aed1e9574ba64159f199 Mon Sep 17 00:00:00 2001 From: Phil26AT Date: Mon, 9 Oct 2023 19:12:40 +0200 Subject: [PATCH 1/8] Check for model initialization during eval --- gluefactory/eval/io.py | 3 ++ gluefactory/models/backbones/dinov2.py | 1 + gluefactory/models/base_model.py | 31 +++++++++++++++++++ gluefactory/models/extractors/disk_kornia.py | 1 + .../extractors/keynet_affnet_hardnet.py | 1 + gluefactory/models/extractors/sift_kornia.py | 1 + gluefactory/models/lines/deeplsd.py | 1 + gluefactory/models/matchers/kornia_loftr.py | 1 + .../models/matchers/lightglue_pretrained.py | 2 +- 9 files changed, 41 insertions(+), 1 deletion(-) diff --git a/gluefactory/eval/io.py b/gluefactory/eval/io.py index 067e8456..26d2339c 100644 --- a/gluefactory/eval/io.py +++ b/gluefactory/eval/io.py @@ -2,6 +2,7 @@ from pathlib import Path from pprint import pprint from typing import Optional +import warnings import pkg_resources from omegaconf import OmegaConf @@ -89,6 +90,8 @@ def load_model(model_conf, checkpoint): model = load_experiment(checkpoint, conf=model_conf).eval() else: model = get_model("two_view_pipeline")(model_conf).eval() + if not model.is_initialized(): + warnings.warn("The provided input did not initialize all parameters. Aborting.") return model diff --git a/gluefactory/models/backbones/dinov2.py b/gluefactory/models/backbones/dinov2.py index 48a48b59..6a2121d7 100644 --- a/gluefactory/models/backbones/dinov2.py +++ b/gluefactory/models/backbones/dinov2.py @@ -10,6 +10,7 @@ class DinoV2(BaseModel): def _init(self, conf): self.net = torch.hub.load("facebookresearch/dinov2", conf.weights) + self.set_initialized(True) def _forward(self, data): img = data["image"] diff --git a/gluefactory/models/base_model.py b/gluefactory/models/base_model.py index 7313d986..9dc5c231 100644 --- a/gluefactory/models/base_model.py +++ b/gluefactory/models/base_model.py @@ -8,6 +8,7 @@ import omegaconf from omegaconf import OmegaConf from torch import nn +from typing import Mapping, Any class MetaModel(ABCMeta): @@ -60,6 +61,8 @@ class BaseModel(nn.Module, metaclass=MetaModel): required_data_keys = [] strict_conf = False + weights_initialized = False + def __init__(self, conf): """Perform some logic and call the _init method of the child model.""" super().__init__() @@ -125,3 +128,31 @@ def _forward(self, data): def loss(self, pred, data): """To be implemented by the child class.""" raise NotImplementedError + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + """Load the state dict of the model, and set the model to initialized.""" + incompatible_keys = super().load_state_dict(state_dict, strict=strict) + self.set_initialized(True) + return incompatible_keys + + def is_initialized(self): + """Recursively check if the model is initialized, i.e. weights are loaded""" + is_initialized = True # initialize to true and perform recursive and + for _, w in self.named_children(): + if isinstance(w, BaseModel): + # if children is BaseModel, we perform recursive check + is_initialized = is_initialized and w.is_initialized() + else: + # else, we check if self is initialized or the children has no params + n_params = len(list(w.parameters())) + is_initialized = is_initialized and ( + n_params == 0 or self.weights_initialized + ) + return is_initialized + + def set_initialized(self, to: bool = True): + """Recursively set the initialization state.""" + self.weights_initialized = to + for _, w in self.named_parameters(): + if isinstance(w, BaseModel): + w.set_initialized(to) diff --git a/gluefactory/models/extractors/disk_kornia.py b/gluefactory/models/extractors/disk_kornia.py index 4d60973d..3f39a369 100644 --- a/gluefactory/models/extractors/disk_kornia.py +++ b/gluefactory/models/extractors/disk_kornia.py @@ -21,6 +21,7 @@ class DISK(BaseModel): def _init(self, conf): self.model = kornia.feature.DISK.from_pretrained(conf.weights) + self.set_initialized(True) def _get_dense_outputs(self, images): B = images.shape[0] diff --git a/gluefactory/models/extractors/keynet_affnet_hardnet.py b/gluefactory/models/extractors/keynet_affnet_hardnet.py index b9091ea4..633cf481 100644 --- a/gluefactory/models/extractors/keynet_affnet_hardnet.py +++ b/gluefactory/models/extractors/keynet_affnet_hardnet.py @@ -21,6 +21,7 @@ def _init(self, conf): upright=conf.upright, scale_laf=conf.scale_laf, ) + self.set_initialized(True) def _forward(self, data): image = data["image"] diff --git a/gluefactory/models/extractors/sift_kornia.py b/gluefactory/models/extractors/sift_kornia.py index 78810e66..c761ad98 100644 --- a/gluefactory/models/extractors/sift_kornia.py +++ b/gluefactory/models/extractors/sift_kornia.py @@ -19,6 +19,7 @@ def _init(self, conf): self.sift = kornia.feature.SIFTFeature( num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift ) + self.set_initialized(True) def _forward(self, data): lafs, scores, descriptors = self.sift(data["image"]) diff --git a/gluefactory/models/lines/deeplsd.py b/gluefactory/models/lines/deeplsd.py index c35aa01e..35e9b855 100644 --- a/gluefactory/models/lines/deeplsd.py +++ b/gluefactory/models/lines/deeplsd.py @@ -34,6 +34,7 @@ def _init(self, conf): ckpt = torch.load(ckpt, map_location="cpu") self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval() self.net.load_state_dict(ckpt["model"]) + self.set_initialized(True) def download_model(self, path): import subprocess diff --git a/gluefactory/models/matchers/kornia_loftr.py b/gluefactory/models/matchers/kornia_loftr.py index 45a20b7a..29299121 100644 --- a/gluefactory/models/matchers/kornia_loftr.py +++ b/gluefactory/models/matchers/kornia_loftr.py @@ -13,6 +13,7 @@ class LoFTRModule(BaseModel): def _init(self, conf): self.net = kornia.feature.LoFTR(pretrained="outdoor") + self.set_initialized(True) def _forward(self, data): image0 = data["view0"]["image"] diff --git a/gluefactory/models/matchers/lightglue_pretrained.py b/gluefactory/models/matchers/lightglue_pretrained.py index 2e7c71b6..7fd9e686 100644 --- a/gluefactory/models/matchers/lightglue_pretrained.py +++ b/gluefactory/models/matchers/lightglue_pretrained.py @@ -18,7 +18,7 @@ class LightGlue(BaseModel): def _init(self, conf): dconf = OmegaConf.to_container(conf) self.net = LightGlue_(dconf.pop("features"), **dconf).cuda() - # self.net.compile() + self.weights_initialized = True def _forward(self, data): view0 = { From 242a03704bfc709417444ba2a26ee2bf683a4f00 Mon Sep 17 00:00:00 2001 From: Phil26AT Date: Mon, 9 Oct 2023 19:25:07 +0200 Subject: [PATCH 2/8] Change from warning to assert --- gluefactory/eval/io.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gluefactory/eval/io.py b/gluefactory/eval/io.py index 26d2339c..f0345efc 100644 --- a/gluefactory/eval/io.py +++ b/gluefactory/eval/io.py @@ -2,7 +2,6 @@ from pathlib import Path from pprint import pprint from typing import Optional -import warnings import pkg_resources from omegaconf import OmegaConf @@ -91,7 +90,10 @@ def load_model(model_conf, checkpoint): else: model = get_model("two_view_pipeline")(model_conf).eval() if not model.is_initialized(): - warnings.warn("The provided input did not initialize all parameters. Aborting.") + assert model.is_initialized(), ( + "The provided model has non-initialized parameters. " + + "Try to load a checkpoint instead." + ) return model From 79e30ad95c734cd1b33cf8d125a88493c5771bce Mon Sep 17 00:00:00 2001 From: Phil26AT Date: Tue, 10 Oct 2023 12:37:35 +0200 Subject: [PATCH 3/8] change assertion to ValueError --- gluefactory/eval/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gluefactory/eval/io.py b/gluefactory/eval/io.py index f0345efc..6a55d59e 100644 --- a/gluefactory/eval/io.py +++ b/gluefactory/eval/io.py @@ -90,7 +90,7 @@ def load_model(model_conf, checkpoint): else: model = get_model("two_view_pipeline")(model_conf).eval() if not model.is_initialized(): - assert model.is_initialized(), ( + raise ValueError( "The provided model has non-initialized parameters. " + "Try to load a checkpoint instead." ) From 15c94323820eb4ec2d26d1c21ab0b4d107dee62a Mon Sep 17 00:00:00 2001 From: Phil26AT Date: Tue, 10 Oct 2023 12:38:20 +0200 Subject: [PATCH 4/8] fix isort and rename are_weights_initialized --- gluefactory/models/base_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gluefactory/models/base_model.py b/gluefactory/models/base_model.py index 9dc5c231..e50197cc 100644 --- a/gluefactory/models/base_model.py +++ b/gluefactory/models/base_model.py @@ -4,11 +4,11 @@ from abc import ABCMeta, abstractmethod from copy import copy +from typing import Any, Mapping import omegaconf from omegaconf import OmegaConf from torch import nn -from typing import Mapping, Any class MetaModel(ABCMeta): @@ -61,7 +61,7 @@ class BaseModel(nn.Module, metaclass=MetaModel): required_data_keys = [] strict_conf = False - weights_initialized = False + are_weights_initialized = False def __init__(self, conf): """Perform some logic and call the _init method of the child model.""" @@ -146,7 +146,7 @@ def is_initialized(self): # else, we check if self is initialized or the children has no params n_params = len(list(w.parameters())) is_initialized = is_initialized and ( - n_params == 0 or self.weights_initialized + n_params == 0 or self.are_weights_initialized ) return is_initialized From b2a59ab82efdb2bd4a0bd5876dc6a9988c869fea Mon Sep 17 00:00:00 2001 From: Phil26AT Date: Tue, 10 Oct 2023 12:38:35 +0200 Subject: [PATCH 5/8] cleanup set_initialized() --- gluefactory/models/backbones/dinov2.py | 2 +- gluefactory/models/extractors/disk_kornia.py | 2 +- gluefactory/models/extractors/keynet_affnet_hardnet.py | 2 +- gluefactory/models/extractors/sift_kornia.py | 2 +- gluefactory/models/lines/deeplsd.py | 2 +- gluefactory/models/matchers/kornia_loftr.py | 2 +- gluefactory/models/matchers/lightglue_pretrained.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/gluefactory/models/backbones/dinov2.py b/gluefactory/models/backbones/dinov2.py index 6a2121d7..cf828523 100644 --- a/gluefactory/models/backbones/dinov2.py +++ b/gluefactory/models/backbones/dinov2.py @@ -10,7 +10,7 @@ class DinoV2(BaseModel): def _init(self, conf): self.net = torch.hub.load("facebookresearch/dinov2", conf.weights) - self.set_initialized(True) + self.set_initialized() def _forward(self, data): img = data["image"] diff --git a/gluefactory/models/extractors/disk_kornia.py b/gluefactory/models/extractors/disk_kornia.py index 3f39a369..e01ab89d 100644 --- a/gluefactory/models/extractors/disk_kornia.py +++ b/gluefactory/models/extractors/disk_kornia.py @@ -21,7 +21,7 @@ class DISK(BaseModel): def _init(self, conf): self.model = kornia.feature.DISK.from_pretrained(conf.weights) - self.set_initialized(True) + self.set_initialized() def _get_dense_outputs(self, images): B = images.shape[0] diff --git a/gluefactory/models/extractors/keynet_affnet_hardnet.py b/gluefactory/models/extractors/keynet_affnet_hardnet.py index 633cf481..419ee972 100644 --- a/gluefactory/models/extractors/keynet_affnet_hardnet.py +++ b/gluefactory/models/extractors/keynet_affnet_hardnet.py @@ -21,7 +21,7 @@ def _init(self, conf): upright=conf.upright, scale_laf=conf.scale_laf, ) - self.set_initialized(True) + self.set_initialized() def _forward(self, data): image = data["image"] diff --git a/gluefactory/models/extractors/sift_kornia.py b/gluefactory/models/extractors/sift_kornia.py index c761ad98..7a1e74d2 100644 --- a/gluefactory/models/extractors/sift_kornia.py +++ b/gluefactory/models/extractors/sift_kornia.py @@ -19,7 +19,7 @@ def _init(self, conf): self.sift = kornia.feature.SIFTFeature( num_features=self.conf.max_num_keypoints, rootsift=self.conf.rootsift ) - self.set_initialized(True) + self.set_initialized() def _forward(self, data): lafs, scores, descriptors = self.sift(data["image"]) diff --git a/gluefactory/models/lines/deeplsd.py b/gluefactory/models/lines/deeplsd.py index 35e9b855..122f4b4f 100644 --- a/gluefactory/models/lines/deeplsd.py +++ b/gluefactory/models/lines/deeplsd.py @@ -34,7 +34,7 @@ def _init(self, conf): ckpt = torch.load(ckpt, map_location="cpu") self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval() self.net.load_state_dict(ckpt["model"]) - self.set_initialized(True) + self.set_initialized() def download_model(self, path): import subprocess diff --git a/gluefactory/models/matchers/kornia_loftr.py b/gluefactory/models/matchers/kornia_loftr.py index 29299121..6fbd47b0 100644 --- a/gluefactory/models/matchers/kornia_loftr.py +++ b/gluefactory/models/matchers/kornia_loftr.py @@ -13,7 +13,7 @@ class LoFTRModule(BaseModel): def _init(self, conf): self.net = kornia.feature.LoFTR(pretrained="outdoor") - self.set_initialized(True) + self.set_initialized() def _forward(self, data): image0 = data["view0"]["image"] diff --git a/gluefactory/models/matchers/lightglue_pretrained.py b/gluefactory/models/matchers/lightglue_pretrained.py index 7fd9e686..b23976db 100644 --- a/gluefactory/models/matchers/lightglue_pretrained.py +++ b/gluefactory/models/matchers/lightglue_pretrained.py @@ -18,7 +18,7 @@ class LightGlue(BaseModel): def _init(self, conf): dconf = OmegaConf.to_container(conf) self.net = LightGlue_(dconf.pop("features"), **dconf).cuda() - self.weights_initialized = True + self.set_initialized() def _forward(self, data): view0 = { From a42637b53af589fbe066b0530bae5e3a1e064041 Mon Sep 17 00:00:00 2001 From: Phil26AT Date: Tue, 10 Oct 2023 12:44:56 +0200 Subject: [PATCH 6/8] fix variable name bug --- gluefactory/models/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gluefactory/models/base_model.py b/gluefactory/models/base_model.py index e50197cc..515b4b7f 100644 --- a/gluefactory/models/base_model.py +++ b/gluefactory/models/base_model.py @@ -152,7 +152,7 @@ def is_initialized(self): def set_initialized(self, to: bool = True): """Recursively set the initialization state.""" - self.weights_initialized = to + self.are_weights_initialized = to for _, w in self.named_parameters(): if isinstance(w, BaseModel): w.set_initialized(to) From 62596860978b847b6cdcd57ee2a1fbd403517362 Mon Sep 17 00:00:00 2001 From: Philipp Lindenberger Date: Tue, 10 Oct 2023 13:15:55 +0200 Subject: [PATCH 7/8] Make load_state_dict forward compatible Co-authored-by: Paul-Edouard Sarlin <15985472+sarlinpe@users.noreply.github.com> --- gluefactory/models/base_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gluefactory/models/base_model.py b/gluefactory/models/base_model.py index 515b4b7f..5579eddc 100644 --- a/gluefactory/models/base_model.py +++ b/gluefactory/models/base_model.py @@ -129,11 +129,11 @@ def loss(self, pred, data): """To be implemented by the child class.""" raise NotImplementedError - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + def load_state_dict(self, *args, **kwargs): """Load the state dict of the model, and set the model to initialized.""" - incompatible_keys = super().load_state_dict(state_dict, strict=strict) - self.set_initialized(True) - return incompatible_keys + ret = super().load_state_dict(*args, **kwargs) + self.set_initialized() + return ret def is_initialized(self): """Recursively check if the model is initialized, i.e. weights are loaded""" From 66b9f679943d9b9e92fc51b5c9039f30ca37f305 Mon Sep 17 00:00:00 2001 From: Phil26AT Date: Tue, 10 Oct 2023 13:16:40 +0200 Subject: [PATCH 8/8] Remove unused imports --- gluefactory/models/base_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gluefactory/models/base_model.py b/gluefactory/models/base_model.py index 5579eddc..b4f66288 100644 --- a/gluefactory/models/base_model.py +++ b/gluefactory/models/base_model.py @@ -4,7 +4,6 @@ from abc import ABCMeta, abstractmethod from copy import copy -from typing import Any, Mapping import omegaconf from omegaconf import OmegaConf