diff --git a/README.md b/README.md
index 95f060d4..99e7d415 100644
--- a/README.md
+++ b/README.md
@@ -188,6 +188,95 @@ AP_lines: 69.22
+#### EVD
+
+The dataset is downloaded automatically if it is not found on disk and requires about 27 MB of free disk space.
+
+
+<details>
+<summary>[Evaluating LightGlue]</summary>
+
+To evaluate LightGlue on EVD, run:
+```bash
+python -m gluefactory.eval.evd --conf gluefactory/configs/superpoint+lightglue-official.yaml
+```
+You should expect the following results:
+```
+{'H_error_dlt@10px': 0.0808,
+ 'H_error_dlt@1px': 0.0,
+ 'H_error_dlt@20px': 0.1443,
+ 'H_error_dlt@5px': 0.0,
+ 'H_error_ransac@10px': 0.1045,
+ 'H_error_ransac@1px': 0.0,
+ 'H_error_ransac@20px': 0.1189,
+ 'H_error_ransac@5px': 0.0553,
+ 'H_error_ransac_mAA': 0.069675,
+ 'mH_error_dlt': nan,
+ 'mH_error_ransac': nan,
+ 'mnum_keypoints': 2048.0,
+ 'mnum_matches': 11.0,
+ 'mprec@1px': 0.0,
+ 'mprec@3px': 0.0,
+ 'mransac_inl': 5.0,
+ 'mransac_inl%': 0.089}
+```
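+
+The RANSAC threshold can be overridden from the command line like any other config entry; setting `eval.ransac_th=-1` should sweep a fixed set of thresholds and keep the best one (per the `ransac_th` comment in `gluefactory/eval/evd.py`):
+```bash
+python -m gluefactory.eval.evd --conf gluefactory/configs/superpoint+lightglue-official.yaml eval.ransac_th=-1
+```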
+
+</details>
+
+Here are the results as Area Under the Curve (AUC) of the homography error at 1/5/10/20 pixels:
+
+<details>
+<summary>[LightGlue on EVD]</summary>
+
+| Methods (2K features if not specified) | [PoseLib](gluefactory/robust_estimators/homography/poselib.py) |
+| ----------------------------------------------------------- | ---------------------------------------------------------------------- |
+| [SuperPoint + SuperGlue](gluefactory/configs/superpoint+superglue-official.yaml) | 0.0 / 5.4 / 10.1 / 11.7 |
+| [SuperPoint + LightGlue](gluefactory/configs/superpoint+lightglue-official.yaml) | 0.0 / 5.5 / 10.4 / 11.8 |
+| [SIFT (4K) + LightGlue](gluefactory/configs/sift+lightglue-official.yaml) | 0.0 / 3.8 / 5.2 / 10.0 |
+| [DoGHardNet + LightGlue](gluefactory/configs/doghardnet+lightglue-official.yaml) | 0.0 / 5.5 / 10.5 / 11.9 |
+| [ALIKED + LightGlue](gluefactory/configs/aliked+lightglue-official.yaml) | 0.0 / 5.4 / 12.4 / 16.2 |
+| [DISK + LightGlue](gluefactory/configs/disk+lightglue-official.yaml) | 0.0 / 0.0 / 6.9 / 10.1 |
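+
+For reference, a minimal sketch of how these AUC numbers are obtained from per-pair errors, using the `AUCMetric` helper that `gluefactory.eval.evd` relies on (the error values below are made up):
+```python
+import numpy as np
+
+from gluefactory.utils.tools import AUCMetric
+
+# hypothetical per-pair homography errors in pixels
+errors = np.array([0.8, 3.2, 7.5, 25.0])
+thresholds = [1, 5, 10, 20]
+aucs = AUCMetric(thresholds, errors).compute()  # one AUC per threshold
+print(dict(zip(thresholds, aucs)))
+```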
+
+</details>
+
+
+#### WxBS
+
+The dataset is downloaded automatically if it is not found on disk and requires about 40 MB of free disk space.
+
+
+<details>
+<summary>[Evaluating LightGlue]</summary>
+
+To evaluate LightGlue on WxBS, run:
+```bash
+python -m gluefactory.eval.wxbs --conf gluefactory/configs/superpoint+lightglue-official.yaml
+```
+You should expect the following results:
+```
+{'epi_error@10px': 0.6141352941176471,
+ 'epi_error@1px': 0.2968,
+ 'epi_error@20px': 0.6937882352941176,
+ 'epi_error@5px': 0.5143617647058826,
+ 'epi_error_mAA': 0.5297713235294118,
+ 'mnum_keypoints': 2048.0,
+ 'mnum_matches': 99.5,
+ 'mransac_inl': 65.0,
+ 'mransac_inl%': nan}
+```
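+
+An OpenCV-based fundamental-matrix estimator is also available; assuming it is registered under the name `opencv` like the other robust estimators, it can be selected with a config override:
+```bash
+python -m gluefactory.eval.wxbs --conf gluefactory/configs/superpoint+lightglue-official.yaml eval.estimator=opencv
+```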
+
+</details>
+
+Here are the results as Area Under the Curve (AUC) of the epipolar error at 1/5/10/20 pixels:
+
+<details>
+<summary>[LightGlue on WxBS]</summary>
+
+| Methods (2K features if not specified) | [PoseLib](gluefactory/robust_estimators/fundamental_matrix/poselib.py) |
+| ----------------------------------------------------------- | ---------------------------------------------------------------------- |
+| [SuperPoint + SuperGlue](gluefactory/configs/superpoint+superglue-official.yaml) | 13.2 / 39.9 / 49.7 / 56.7 |
+| [SuperPoint + LightGlue](gluefactory/configs/superpoint+lightglue-official.yaml) | 12.6 / 34.5 / 44.0 / 52.2 |
+| [SIFT (4K) + LightGlue](gluefactory/configs/sift+lightglue-official.yaml) | 9.5 / 22.7 / 29.0 / 34.2 |
+| [DoGHardNet + LightGlue](gluefactory/configs/doghardnet+lightglue-official.yaml) | 10.0 / 29.6 / 39.0 / 49.2 |
+| [ALIKED + LightGlue](gluefactory/configs/aliked+lightglue-official.yaml) | 18.7 / 46.2 / 56.0 / 63.5 |
+| [DISK + LightGlue](gluefactory/configs/disk+lightglue-official.yaml) | 15.1 / 39.3 / 48.2 / 55.2 |
+
+</details>
+
+
#### Image Matching Challenge 2021
Coming soon!
diff --git a/gluefactory/datasets/evd.py b/gluefactory/datasets/evd.py
new file mode 100644
index 00000000..47a0f5f7
--- /dev/null
+++ b/gluefactory/datasets/evd.py
@@ -0,0 +1,126 @@
+"""
+Simply load images from a folder or nested folders (does not have any split).
+"""
+import argparse
+import logging
+import os
+import zipfile
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+
+from ..settings import DATA_PATH
+from ..utils.image import ImagePreprocessor, load_image
+from ..utils.tools import fork_rng
+from ..visualization.viz2d import plot_image_grid
+from .base_dataset import BaseDataset
+
+logger = logging.getLogger(__name__)
+
+
+def read_homography(path):
+    """Read a 3x3 homography matrix from a whitespace-separated text file."""
+    with open(path) as hf:
+        lines = hf.readlines()
+    H = []
+    for line in lines:
+        H.append([float(x) for x in line.replace("\t", " ").strip().split(" ") if x])
+    H = np.array(H)
+    return H / H[2, 2]
+
+
+class EVD(BaseDataset, torch.utils.data.Dataset):
+ default_conf = {
+ "preprocessing": ImagePreprocessor.default_conf,
+ "data_dir": "EVD",
+ "subset": None,
+ "grayscale": False,
+ }
+ url = "http://cmp.felk.cvut.cz/wbs/datasets/EVD.zip"
+
+ def _init(self, conf):
+ assert conf.batch_size == 1
+ self.preprocessor = ImagePreprocessor(conf.preprocessing)
+ self.root = DATA_PATH / conf.data_dir
+ if not self.root.exists():
+ logger.info("Downloading the EVD dataset.")
+ self.download()
+ self.pairs = self.index_dataset()
+ if not self.pairs:
+ raise ValueError("No image found!")
+
+ def download(self):
+ data_dir = self.root.parent
+ data_dir.mkdir(exist_ok=True, parents=True)
+ zip_path = data_dir / self.url.rsplit("/", 1)[-1]
+ torch.hub.download_url_to_file(self.url, zip_path)
+ with zipfile.ZipFile(zip_path, 'r') as z:
+ z.extractall(data_dir)
+ os.unlink(zip_path)
+
+    def index_dataset(self):
+        # pairs are stored as 1/<name>.png and 2/<name>.png,
+        # with the GT homography in h/<name>.txt
+        names = sorted(os.listdir(os.path.join(self.root, "1")))
+        img_pairs_list = []
+        for name in names:
+            if name == ".DS_Store":
+                continue
+            img_pairs_list.append(
+                (
+                    os.path.join(self.root, "1", name),
+                    os.path.join(self.root, "2", name),
+                    os.path.join(self.root, "h", name.replace("png", "txt")),
+                )
+            )
+        return img_pairs_list
+
+ def __getitem__(self, idx):
+ imgfname1, imgfname2, h_fname = self.pairs[idx]
+ H = read_homography(h_fname)
+ data0 = self.preprocessor(load_image(imgfname1))
+ data1 = self.preprocessor(load_image(imgfname2))
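+        # adapt the GT homography to the preprocessed (resized) images:
+        # H_0to1' = T1 @ H @ T0^-1, with T0/T1 the per-view preprocessing transforms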
+ H = data1["transform"] @ H @ np.linalg.inv(data0["transform"])
+ pair_name = imgfname1.split('/')[-1].split('.')[0]
+ return {
+ "H_0to1": H.astype(np.float32),
+ "scene": pair_name,
+ "view0": data0,
+ "view1": data1,
+ "idx": idx,
+ "name": pair_name,
+ }
+
+ def __len__(self):
+ return len(self.pairs)
+
+ def get_dataset(self, split):
+ return self
+
+
+def visualize(args):
+ conf = {
+ "batch_size": 1,
+ "num_workers": 8,
+ "prefetch_factor": 1,
+ }
+ conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
+ dataset = EVD(conf)
+ loader = dataset.get_data_loader("test")
+ logger.info("The dataset has %d elements.", len(loader))
+
+ with fork_rng(seed=dataset.conf.seed):
+ images = []
+ for _, data in zip(range(args.num_items), loader):
+ images.append(
+ [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
+ )
+ plot_image_grid(images, dpi=args.dpi)
+ plt.tight_layout()
+ plt.show()
+
+
+if __name__ == "__main__":
+ from .. import logger # overwrite the logger
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--num_items", type=int, default=8)
+ parser.add_argument("--dpi", type=int, default=100)
+ parser.add_argument("dotlist", nargs="*")
+ args = parser.parse_intermixed_args()
+ visualize(args)
diff --git a/gluefactory/datasets/wxbs.py b/gluefactory/datasets/wxbs.py
new file mode 100644
index 00000000..443c07ae
--- /dev/null
+++ b/gluefactory/datasets/wxbs.py
@@ -0,0 +1,148 @@
+"""
+Simply load images from a folder or nested folders (does not have any split).
+"""
+
+import argparse
+import logging
+import os
+import zipfile
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+
+from ..geometry.homography import warp_points
+from ..settings import DATA_PATH
+from ..utils.image import ImagePreprocessor, load_image
+from ..utils.tools import fork_rng
+from ..visualization.viz2d import plot_image_grid
+from .base_dataset import BaseDataset
+
+logger = logging.getLogger(__name__)
+
+
+class WxBSDataset(BaseDataset, torch.utils.data.Dataset):
+ """Wide multiple baselines stereo dataset."""
+ url = 'http://cmp.felk.cvut.cz/wbs/datasets/WxBS_v1.1.zip'
+ zip_fname = 'WxBS_v1.1.zip'
+ validation_pairs = ['kyiv_dolltheater2', 'petrzin']
+ default_conf = {
+ "preprocessing": ImagePreprocessor.default_conf,
+ "data_dir": "WxBS",
+ "subset": None,
+ "grayscale": False,
+    }
+
+    def _init(self, conf):
+ self.preprocessor = ImagePreprocessor(conf.preprocessing)
+ self.root = DATA_PATH / conf.data_dir
+ if not self.root.exists():
+ logger.info("Downloading the WxBS dataset.")
+ self.download()
+ self.pairs = self.index_dataset()
+ if not self.pairs:
+ raise ValueError("No image found!")
+
+ def __len__(self):
+ return len(self.pairs)
+
+ def download(self):
+ data_dir = self.root
+ data_dir.mkdir(exist_ok=True, parents=True)
+ zip_path = data_dir / self.url.rsplit("/", 1)[-1]
+ torch.hub.download_url_to_file(self.url, zip_path)
+ with zipfile.ZipFile(zip_path, 'r') as z:
+ z.extractall(data_dir)
+ os.unlink(zip_path)
+
+    def index_dataset(self):
+        sets = sorted(
+            x for x in os.listdir(self.root) if os.path.isdir(os.path.join(self.root, x))
+        )
+        img_pairs_list = []
+        for s in sets[::-1]:
+            if s == ".DS_Store":
+                continue
+            ss = os.path.join(self.root, s)
+            for p in sorted(os.listdir(ss)):
+                if p == ".DS_Store":
+                    continue
+                cur_dir = os.path.join(ss, p)
+                # each pair directory holds images 01/02 (png or jpg), GT
+                # correspondences, and cross-validation errors of the GT points
+                for ext in ("png", "jpg"):
+                    if os.path.isfile(os.path.join(cur_dir, f"01.{ext}")):
+                        img_pairs_list.append(
+                            (
+                                os.path.join(cur_dir, f"01.{ext}"),
+                                os.path.join(cur_dir, f"02.{ext}"),
+                                os.path.join(cur_dir, "corrs.txt"),
+                                os.path.join(cur_dir, "crossval_errors.txt"),
+                            )
+                        )
+                        break
+        return img_pairs_list
+
+ def __getitem__(self, idx):
+ imgfname1, imgfname2, pts_fname, err_fname = self.pairs[idx]
+ data0 = self.preprocessor(load_image(imgfname1))
+ data1 = self.preprocessor(load_image(imgfname2))
+ pts = np.loadtxt(pts_fname)
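+        # corrs.txt stores GT correspondences as rows (x0, y0, x1, y1);
+        # warp each side by the preprocessing transform of its view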
+ pts[:, :2] = warp_points(pts[:, :2], data0["transform"], False)
+ pts[:, 2:] = warp_points(pts[:, 2:], data1["transform"], False)
+
+ crossval_errors = np.loadtxt(err_fname)
+        pair_name = "_".join(pts_fname.split("/")[-3:-1])
+        scene_name = pts_fname.split("/")[-3]
+ out = {
+ "pts_0to1": pts,
+ "scene": scene_name,
+ "view0": data0,
+ "view1": data1,
+ "idx": idx,
+ "name": pair_name,
+ "crossval_errors": crossval_errors}
+ return out
+
+ def get_dataset(self, split):
+ assert split in ['val', 'test']
+ return self
+
+
+def visualize(args):
+ conf = {
+ "batch_size": 1,
+ "num_workers": 8,
+ "prefetch_factor": 1,
+ }
+ conf = OmegaConf.merge(conf, OmegaConf.from_cli(args.dotlist))
+ dataset = WxBSDataset(conf)
+ loader = dataset.get_data_loader("test")
+ logger.info("The dataset has %d elements.", len(loader))
+
+ with fork_rng(seed=dataset.conf.seed):
+ images = []
+ for _, data in zip(range(args.num_items), loader):
+ images.append(
+ [data[f"view{i}"]["image"][0].permute(1, 2, 0) for i in range(2)]
+ )
+ plot_image_grid(images, dpi=args.dpi)
+ plt.tight_layout()
+ plt.show()
+
+
+if __name__ == "__main__":
+ from .. import logger # overwrite the logger
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--num_items", type=int, default=8)
+ parser.add_argument("--dpi", type=int, default=100)
+ parser.add_argument("dotlist", nargs="*")
+ args = parser.parse_intermixed_args()
+ visualize(args)
diff --git a/gluefactory/eval/eval_pipeline.py b/gluefactory/eval/eval_pipeline.py
index ac562377..e5384a93 100644
--- a/gluefactory/eval/eval_pipeline.py
+++ b/gluefactory/eval/eval_pipeline.py
@@ -25,8 +25,15 @@ def save_eval(dir, summaries, figures, results):
for k, v in results.items():
arr = np.array(v)
if not np.issubdtype(arr.dtype, np.number):
- arr = arr.astype("object")
- hfile.create_dataset(k, data=arr)
+ if not isinstance(v[0], str):
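+                # ragged per-pair arrays (e.g. per-correspondence epipolar errors)
+                # cannot be stacked; store them as variable-length float64 datasets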
+ arr = np.array([x.astype(np.float64) for x in v])
+ dt = h5py.special_dtype(vlen=np.float64)
+ hfile.create_dataset(k, data=arr, dtype=dt)
+ else:
+ arr = arr.astype("object")
+ hfile.create_dataset(k, data=arr)
+ else:
+ hfile.create_dataset(k, data=arr)
# just to be safe, not used in practice
for k, v in summaries.items():
hfile.attrs[k] = v
diff --git a/gluefactory/eval/evd.py b/gluefactory/eval/evd.py
new file mode 100644
index 00000000..3b473a77
--- /dev/null
+++ b/gluefactory/eval/evd.py
@@ -0,0 +1,204 @@
+from collections import defaultdict
+from collections.abc import Iterable
+from pathlib import Path
+from pprint import pprint
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
+from ..datasets import get_dataset
+from ..models.cache_loader import CacheLoader
+from ..settings import EVAL_PATH
+from ..utils.export_predictions import export_predictions
+from ..utils.tensor import map_tensor
+from ..utils.tools import AUCMetric
+from ..visualization.viz2d import plot_cumulative
+from .eval_pipeline import EvalPipeline
+from .io import get_eval_parser, load_model, parse_eval_args
+from .utils import (
+ eval_homography_dlt,
+ eval_homography_robust,
+ eval_matches_homography,
+ eval_poses,
+)
+
+
+class EVDPipeline(EvalPipeline):
+ default_conf = {
+ "data": {
+ "batch_size": 1,
+ "name": "evd",
+ "num_workers": 1,
+ "preprocessing": {
+ "resize": 600, # we also resize during eval to have comparable metrics
+ "side": "short",
+ },
+ },
+ "model": {
+ "ground_truth": {
+ "name": None, # remove gt matches
+ }
+ },
+ "eval": {
+ "estimator": "poselib",
+ "ransac_th": 1.0, # -1 runs a bunch of thresholds and selects the best
+ },
+ }
+ export_keys = [
+ "keypoints0",
+ "keypoints1",
+ "keypoint_scores0",
+ "keypoint_scores1",
+ "matches0",
+ "matches1",
+ "matching_scores0",
+ "matching_scores1",
+ ]
+
+ optional_export_keys = [
+ "lines0",
+ "lines1",
+ "orig_lines0",
+ "orig_lines1",
+ "line_matches0",
+ "line_matches1",
+ "line_matching_scores0",
+ "line_matching_scores1",
+ ]
+
+ def _init(self, conf):
+ pass
+
+    @classmethod
+    def get_dataloader(cls, data_conf=None):
+        data_conf = data_conf if data_conf else cls.default_conf["data"]
+        dataset = get_dataset("evd")(data_conf)
+        return dataset.get_data_loader("test")
+
+ def get_predictions(self, experiment_dir, model=None, overwrite=False):
+ pred_file = experiment_dir / "predictions.h5"
+ if not pred_file.exists() or overwrite:
+ if model is None:
+ model = load_model(self.conf.model, self.conf.checkpoint)
+ with torch.inference_mode():
+ export_predictions(
+ self.get_dataloader(self.conf.data),
+ model,
+ pred_file,
+ keys=self.export_keys,
+ optional_keys=self.optional_export_keys,
+ )
+ return pred_file
+
+ def run_eval(self, loader, pred_file):
+ assert pred_file.exists()
+ results = defaultdict(list)
+
+ conf = self.conf.eval
+
+ test_thresholds = (
+ ([conf.ransac_th] if conf.ransac_th > 0 else [0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
+ if not isinstance(conf.ransac_th, Iterable)
+ else conf.ransac_th
+ )
+ pose_results = defaultdict(lambda: defaultdict(list))
+ cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
+ for i, data in enumerate(tqdm(loader)):
+ pred = cache_loader(data)
+ # Remove batch dimension
+ data = map_tensor(data, lambda t: torch.squeeze(t, dim=0))
+ # add custom evaluations here
+ if "keypoints0" in pred:
+ results_i = eval_matches_homography(data, pred)
+ results_i = {**results_i, **eval_homography_dlt(data, pred)}
+ else:
+ results_i = {}
+ for th in test_thresholds:
+ pose_results_i = eval_homography_robust(
+ data,
+ pred,
+ {"estimator": conf.estimator, "ransac_th": th},
+ )
+ [pose_results[th][k].append(v) for k, v in pose_results_i.items()]
+
+ # we also store the names for later reference
+ results_i["scenes"] = data["scene"][0]
+ results_i["name"] = data["scene"][0]
+
+ for k, v in results_i.items():
+ results[k].append(v)
+
+ # summarize results as a dict[str, float]
+ # you can also add your custom evaluations here
+ summaries = {}
+ for k, v in results.items():
+ arr = np.array(v)
+            if not np.issubdtype(arr.dtype, np.number):
+ continue
+ summaries[f"m{k}"] = round(np.median(arr), 3)
+
+ auc_ths = [1, 5, 10, 20]
+ best_pose_results, best_th = eval_poses(
+ pose_results, auc_ths=auc_ths, key="H_error_ransac", unit="px"
+ )
+ if "H_error_dlt" in results.keys():
+ dlt_aucs = AUCMetric(auc_ths, results["H_error_dlt"]).compute()
+ for i, ath in enumerate(auc_ths):
+ summaries[f"H_error_dlt@{ath}px"] = dlt_aucs[i]
+
+ results = {**results, **pose_results[best_th]}
+ summaries = {
+ **summaries,
+ **best_pose_results,
+ }
+
+ figures = {
+ "homography_recall": plot_cumulative(
+ {
+ "DLT": results["H_error_dlt"],
+ self.conf.eval.estimator: results["H_error_ransac"],
+ },
+ [0, 20],
+ unit="px",
+ title="Homography ",
+ )
+ }
+
+ return summaries, figures, results
+
+
+if __name__ == "__main__":
+ dataset_name = Path(__file__).stem
+ parser = get_eval_parser()
+ args = parser.parse_intermixed_args()
+
+ default_conf = OmegaConf.create(EVDPipeline.default_conf)
+
+ # mingle paths
+ output_dir = Path(EVAL_PATH, dataset_name)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ name, conf = parse_eval_args(
+ dataset_name,
+ args,
+ "configs/",
+ default_conf,
+ )
+
+ experiment_dir = output_dir / name
+ experiment_dir.mkdir(exist_ok=True)
+
+ pipeline = EVDPipeline(conf)
+ s, f, r = pipeline.run(
+ experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval
+ )
+
+ # print results
+ pprint(s)
+ if args.plot:
+ for name, fig in f.items():
+ fig.canvas.manager.set_window_title(name)
+ plt.show()
diff --git a/gluefactory/eval/utils.py b/gluefactory/eval/utils.py
index b89fe792..b6cddad8 100644
--- a/gluefactory/eval/utils.py
+++ b/gluefactory/eval/utils.py
@@ -2,7 +2,7 @@
import torch
from kornia.geometry.homography import find_homography_dlt
-from ..geometry.epipolar import generalized_epi_dist, relative_pose_error
+from ..geometry.epipolar import (
+    generalized_epi_dist,
+    relative_pose_error,
+    sym_epipolar_distance,
+)
from ..geometry.gt_generation import IGNORE_FEATURE
from ..geometry.homography import homography_corner_error, sym_homography_error
from ..robust_estimators import load_estimator
@@ -22,7 +22,7 @@ def get_matches_scores(kpts0, kpts1, matches0, mscores0):
m0 = matches0 > -1
m1 = matches0[m0]
pts0 = kpts0[m0]
- pts1 = kpts1[m1]
+ pts1 = kpts1[m1.long()]
scores = mscores0[m0]
return pts0, pts1, scores
@@ -69,6 +69,44 @@ def eval_matches_epipolar(data: dict, pred: dict) -> dict:
return results
+def eval_matches_epipolar_via_gt_points(data: dict, pred: dict, conf) -> dict:
+ check_keys_recursive(data, ["view0", "view1", "pts_0to1"])
+ check_keys_recursive(
+ pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
+ )
+
+ kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
+ m0, scores0 = pred["matches0"], pred["matching_scores0"]
+ pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
+
+ results = {}
+
+ estimator = load_estimator("fundamental_matrix", conf["estimator"])(conf)
+ data_ = {
+ "m_kpts0": pts0,
+ "m_kpts1": pts1,
+ }
+ est = estimator(data_)
+ success = est["success"] and (len(est["inliers"]) > 0)
+    if not success:
+        # fill with a large error so failed pairs count as outliers at all thresholds
+        results["epi_error"] = np.full(len(data["pts_0to1"]), 1e6)
+        results["ransac_inl"] = 0
+        results["ransac_inl%"] = 0
+    else:
+        M = est["M_0to1"]
+        inl = est["inliers"].numpy()
+        # symmetric epipolar distance of the GT correspondences w.r.t. the
+        # estimated fundamental matrix
+        epi_err = sym_epipolar_distance(
+            data["pts_0to1"][:, :2].double(),
+            data["pts_0to1"][:, 2:].double(),
+            M.double(),
+            squared=False,
+        )
+        results["epi_error"] = epi_err.detach().cpu().numpy()
+        results["ransac_inl"] = np.sum(inl)
+        results["ransac_inl%"] = np.mean(inl)
+
+ # match metrics
+ results["num_matches"] = pts0.shape[0]
+ results["num_keypoints"] = (kp0.shape[0] + kp1.shape[0]) / 2.0
+
+ return results
+
+
def eval_matches_homography(data: dict, pred: dict) -> dict:
check_keys_recursive(data, ["H_0to1"])
check_keys_recursive(
@@ -224,6 +262,36 @@ def eval_poses(pose_results, auc_ths, key, unit="°"):
return summaries, best_th
+def eval_fundamental_matrices(fm_results, auc_ths, key, unit="°"):
+ pose_aucs = {}
+ best_th = -1
+ for th, results_i in fm_results.items():
+ pair_mean = []
+ for pair_results in results_i[key]:
+ pair_mean.append(AUCMetric(auc_ths, pair_results).compute())
+ pose_aucs[th] = np.array(pair_mean).mean(axis=0)
+ mAAs = {k: np.mean(v) for k, v in pose_aucs.items()}
+ best_th = max(mAAs, key=mAAs.get)
+
+    if len(pose_aucs) > 1:
+        print("Tested RANSAC setups with the following results:")
+        print("AUC", pose_aucs)
+        print("mAA", mAAs)
+        print("best threshold =", best_th)
+
+ summaries = {}
+ for i, ath in enumerate(auc_ths):
+ summaries[f"{key}@{ath}{unit}"] = pose_aucs[best_th][i]
+ summaries[f"{key}_mAA"] = mAAs[best_th]
+
+ for k, v in fm_results[best_th].items():
+ arr = np.array(v)
+        if not np.issubdtype(arr.dtype, np.number):
+ continue
+ summaries[f"m{k}"] = round(np.median(arr), 3)
+ return summaries, best_th
+
+
def get_tp_fp_pts(pred_matches, gt_matches, pred_scores):
"""
Computes the True Positives (TP), False positives (FP), the score associated
diff --git a/gluefactory/eval/wxbs.py b/gluefactory/eval/wxbs.py
new file mode 100644
index 00000000..022bc8b8
--- /dev/null
+++ b/gluefactory/eval/wxbs.py
@@ -0,0 +1,174 @@
+from collections import defaultdict
+from collections.abc import Iterable
+from pathlib import Path
+from pprint import pprint
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
+from ..datasets import get_dataset
+from ..models.cache_loader import CacheLoader
+from ..settings import EVAL_PATH
+from ..utils.export_predictions import export_predictions
+from ..utils.tensor import map_tensor
+from .eval_pipeline import EvalPipeline
+from .io import get_eval_parser, load_model, parse_eval_args
+from .utils import (
+    eval_fundamental_matrices,
+    eval_matches_epipolar_via_gt_points,
+)
+
+
+class WxBSPipeline(EvalPipeline):
+ default_conf = {
+ "data": {
+ "batch_size": 1,
+ "name": "wxbs",
+ "num_workers": 1,
+ "preprocessing": {
+ "resize": 600, # we also resize during eval to have comparable metrics
+ "side": "short",
+ },
+ },
+ "model": {
+ "ground_truth": {
+ "name": None, # remove gt matches
+ }
+ },
+ "eval": {
+ "estimator": "poselib",
+ "ransac_th": 1.0, # -1 runs a bunch of thresholds and selects the best
+ },
+ }
+ export_keys = [
+ "keypoints0",
+ "keypoints1",
+ "keypoint_scores0",
+ "keypoint_scores1",
+ "matches0",
+ "matches1",
+ "matching_scores0",
+ "matching_scores1",
+ ]
+
+ optional_export_keys = []
+
+ def _init(self, conf):
+ pass
+
+    @classmethod
+    def get_dataloader(cls, data_conf=None):
+        data_conf = data_conf if data_conf else cls.default_conf["data"]
+        dataset = get_dataset("wxbs")(data_conf)
+        return dataset.get_data_loader("test")
+
+ def get_predictions(self, experiment_dir, model=None, overwrite=False):
+ pred_file = experiment_dir / "predictions.h5"
+ if not pred_file.exists() or overwrite:
+ if model is None:
+ model = load_model(self.conf.model, self.conf.checkpoint)
+ with torch.inference_mode():
+ export_predictions(
+ self.get_dataloader(self.conf.data),
+ model,
+ pred_file,
+ keys=self.export_keys,
+ optional_keys=self.optional_export_keys,
+ )
+ return pred_file
+
+ def run_eval(self, loader, pred_file):
+ assert pred_file.exists()
+ results = defaultdict(list)
+
+ conf = self.conf.eval
+
+ test_thresholds = (
+ ([conf.ransac_th] if conf.ransac_th > 0 else [0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
+ if not isinstance(conf.ransac_th, Iterable)
+ else conf.ransac_th
+ )
+ pose_results = defaultdict(lambda: defaultdict(list))
+ cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
+ for i, data in enumerate(tqdm(loader)):
+ pred = cache_loader(data)
+ # Remove batch dimension
+ data = map_tensor(data, lambda t: torch.squeeze(t, dim=0))
+ # add custom evaluations here
+ results_i = {}
+ for th in test_thresholds:
+ pose_results_i = eval_matches_epipolar_via_gt_points(
+ data,
+ pred,
+ {"estimator": conf.estimator, "ransac_th": th},
+ )
+ [pose_results[th][k].append(v) for k, v in pose_results_i.items()]
+
+ # we also store the names for later reference
+ results_i["names"] = data["name"][0]
+ results_i["scenes"] = data["scene"][0]
+
+ for k, v in results_i.items():
+ results[k].append(v)
+
+ # summarize results as a dict[str, float]
+ # you can also add your custom evaluations here
+ summaries = {}
+ for k, v in results.items():
+ arr = np.array(v)
+            if not np.issubdtype(arr.dtype, np.number):
+ continue
+ summaries[f"m{k}"] = round(np.median(arr), 3)
+
+ auc_ths = [1, 5, 10, 20]
+ best_pose_results, best_th = eval_fundamental_matrices(
+ pose_results, auc_ths=auc_ths, key="epi_error", unit="px"
+ )
+
+ results = {**results, **pose_results[best_th]}
+ summaries = {
+ **summaries,
+ **best_pose_results,
+ }
+
+ figures = {}
+
+ return summaries, figures, results
+
+
+if __name__ == "__main__":
+ dataset_name = Path(__file__).stem
+ parser = get_eval_parser()
+ args = parser.parse_intermixed_args()
+
+ default_conf = OmegaConf.create(WxBSPipeline.default_conf)
+
+ # mingle paths
+ output_dir = Path(EVAL_PATH, dataset_name)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ name, conf = parse_eval_args(
+ dataset_name,
+ args,
+ "configs/",
+ default_conf,
+ )
+
+ experiment_dir = output_dir / name
+ experiment_dir.mkdir(exist_ok=True)
+
+ pipeline = WxBSPipeline(conf)
+ s, f, r = pipeline.run(
+ experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval
+ )
+
+ # print results
+ pprint(s)
+ if args.plot:
+ for name, fig in f.items():
+ fig.canvas.manager.set_window_title(name)
+ plt.show()
diff --git a/gluefactory/robust_estimators/fundamental_matrix/__init__.py b/gluefactory/robust_estimators/fundamental_matrix/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gluefactory/robust_estimators/fundamental_matrix/opencv.py b/gluefactory/robust_estimators/fundamental_matrix/opencv.py
new file mode 100644
index 00000000..fc596453
--- /dev/null
+++ b/gluefactory/robust_estimators/fundamental_matrix/opencv.py
@@ -0,0 +1,53 @@
+import cv2
+import torch
+
+from ..base_estimator import BaseEstimator
+
+
+class OpenCVFundamentalMatrixEstimator(BaseEstimator):
+ default_conf = {
+ "ransac_th": 1.0,
+ "options": {"method": "ransac", "max_iters": 30000, "confidence": 0.995},
+ }
+
+ required_data_keys = ["m_kpts0", "m_kpts1"]
+
+ def _init(self, conf):
+ self.solver = {
+ "ransac": cv2.RANSAC,
+ "lmeds": cv2.LMEDS,
+ "rho": cv2.RHO,
+ "usac": cv2.USAC_DEFAULT,
+ "usac_fast": cv2.USAC_FAST,
+ "usac_accurate": cv2.USAC_ACCURATE,
+ "usac_prosac": cv2.USAC_PROSAC,
+ "usac_magsac": cv2.USAC_MAGSAC,
+ }[conf.options.method]
+
+ def _forward(self, data):
+ pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
+
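+        # cv2.findFundamentalMat may return None or raise on degenerate input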
+ try:
+ M, mask = cv2.findFundamentalMat(
+ pts0.numpy(),
+ pts1.numpy(),
+ self.solver,
+ self.conf.ransac_th,
+ maxIters=self.conf.options.max_iters,
+ confidence=self.conf.options.confidence,
+ )
+ success = M is not None
+ except cv2.error:
+ success = False
+ if not success:
+ M = torch.eye(3, device=pts0.device, dtype=pts0.dtype)
+ inl = torch.zeros_like(pts0[:, 0]).bool()
+ else:
+ M = torch.tensor(M).to(pts0)
+ inl = torch.tensor(mask).bool().to(pts0.device)
+
+ return {
+ "success": success,
+ "M_0to1": M,
+ "inliers": inl,
+ }
diff --git a/gluefactory/robust_estimators/fundamental_matrix/poselib.py b/gluefactory/robust_estimators/fundamental_matrix/poselib.py
new file mode 100644
index 00000000..247d7280
--- /dev/null
+++ b/gluefactory/robust_estimators/fundamental_matrix/poselib.py
@@ -0,0 +1,40 @@
+import poselib
+import torch
+from omegaconf import OmegaConf
+
+from ..base_estimator import BaseEstimator
+
+
+class PoseLibFundamentalMatrixEstimator(BaseEstimator):
+ default_conf = {"ransac_th": 1.0, "options": {}}
+
+ required_data_keys = ["m_kpts0", "m_kpts1"]
+
+ def _init(self, conf):
+ pass
+
+ def _forward(self, data):
+ pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
+ M, info = poselib.estimate_fundamental(
+ pts0.detach().cpu().numpy(),
+ pts1.detach().cpu().numpy(),
+ {
+ "max_reproj_error": self.conf.ransac_th,
+ **OmegaConf.to_container(self.conf.options),
+ },
+ )
+ success = M is not None
+ if not success:
+ M = torch.eye(3, device=pts0.device, dtype=pts0.dtype)
+ inl = torch.zeros_like(pts0[:, 0]).bool()
+ else:
+ M = torch.tensor(M).to(pts0)
+ inl = torch.tensor(info["inliers"]).bool().to(pts0.device)
+
+ estimation = {
+ "success": success,
+ "M_0to1": M,
+ "inliers": inl,
+ }
+
+ return estimation