From 398c4b8c21cea06849cf249236b5f2da101305b5 Mon Sep 17 00:00:00 2001 From: Philipp Lindenberger Date: Fri, 13 Oct 2023 13:19:19 +0200 Subject: [PATCH] Fix sparse depth export for megadepth (#13) * Never export sparse depth * add sparse depth support --- gluefactory/models/cache_loader.py | 9 +++++++++ gluefactory/scripts/export_megadepth.py | 10 +++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/gluefactory/models/cache_loader.py b/gluefactory/models/cache_loader.py index 3fbf0f71..b345a997 100644 --- a/gluefactory/models/cache_loader.py +++ b/gluefactory/models/cache_loader.py @@ -29,6 +29,15 @@ def pad_local_features(pred: dict, seq_l: int): pred["scales"] = pad_to_length(pred["scales"], seq_l, -1, mode="zeros") if "oris" in pred.keys(): pred["oris"] = pad_to_length(pred["oris"], seq_l, -1, mode="zeros") + + if "depth_keypoints" in pred.keys(): + pred["depth_keypoints"] = pad_to_length( + pred["depth_keypoints"], seq_l, -1, mode="zeros" + ) + if "valid_depth_keypoints" in pred.keys(): + pred["valid_depth_keypoints"] = pad_to_length( + pred["valid_depth_keypoints"], seq_l, -1, mode="zeros" + ) return pred diff --git a/gluefactory/scripts/export_megadepth.py b/gluefactory/scripts/export_megadepth.py index 95e89d81..f332f2fa 100644 --- a/gluefactory/scripts/export_megadepth.py +++ b/gluefactory/scripts/export_megadepth.py @@ -133,15 +133,18 @@ def run_export(feature_file, scene, args): conf = OmegaConf.create(conf) - keys = configs[args.method]["keys"] + ["depth_keypoints", "valid_depth_keypoints"] + keys = configs[args.method]["keys"] dataset = get_dataset(conf.data.name)(conf.data) loader = dataset.get_data_loader(conf.split or "test") device = "cuda" if torch.cuda.is_available() else "cpu" model = get_model(conf.model.name)(conf.model).eval().to(device) - callback_fn = None - # callback_fn=get_kp_depth # use this to store the depth of each keypoint + if args.export_sparse_depth: + callback_fn = get_kp_depth # use this to store the depth of each keypoint + keys = keys + ["depth_keypoints", "valid_depth_keypoints"] + else: + callback_fn = None export_predictions( loader, model, feature_file, as_half=True, keys=keys, callback_fn=callback_fn ) @@ -153,6 +156,7 @@ def run_export(feature_file, scene, args): parser.add_argument("--method", type=str, default="sp") parser.add_argument("--scenes", type=str, default=None) parser.add_argument("--num_workers", type=int, default=0) + parser.add_argument("--export_sparse_depth", action="store_true") args = parser.parse_args() export_name = configs[args.method]["name"]