diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml new file mode 100644 index 00000000..05d61c16 --- /dev/null +++ b/.github/workflows/python-tests.yml @@ -0,0 +1,30 @@ +name: Python Tests +on: + push: + branches: + - main + pull_request: + types: [ assigned, opened, synchronize, reopened ] +jobs: + build: + name: Run Python Tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + cache: 'pip' + - name: Install dependencies + run: | + sudo apt-get remove libunwind-14-dev || true + sudo apt-get install -y libceres-dev libeigen3-dev + python -m pip install --upgrade pip + python -m pip install pytest pytest-cov + python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu + python -m pip install -e .[dev] + python -m pip install -e .[extra] + - name: Test with pytest + run: | + set -o pipefail + pytest --junitxml=pytest.xml --cov=gluefactory tests/ \ No newline at end of file diff --git a/assets/boat1.png b/assets/boat1.png new file mode 100644 index 00000000..89cca50e Binary files /dev/null and b/assets/boat1.png differ diff --git a/assets/boat2.png b/assets/boat2.png new file mode 100644 index 00000000..5fb961bc Binary files /dev/null and b/assets/boat2.png differ diff --git a/gluefactory/eval/hpatches.py b/gluefactory/eval/hpatches.py index 8be7b704..bcd799c3 100644 --- a/gluefactory/eval/hpatches.py +++ b/gluefactory/eval/hpatches.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import numpy as np +import torch from omegaconf import OmegaConf from tqdm import tqdm @@ -12,6 +13,7 @@ from ..models.cache_loader import CacheLoader from ..settings import EVAL_PATH from ..utils.export_predictions import export_predictions +from ..utils.tensor import map_tensor from ..utils.tools import AUCMetric from ..visualization.viz2d import plot_cumulative from .eval_pipeline import EvalPipeline @@ -105,9 +107,11 @@ def run_eval(self, loader, pred_file): cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval() for i, data in enumerate(tqdm(loader)): pred = cache_loader(data) + # Remove batch dimension + data = map_tensor(data, lambda t: torch.squeeze(t, dim=0)) # add custom evaluations here if "keypoints0" in pred: - results_i = eval_matches_homography(data, pred, {}) + results_i = eval_matches_homography(data, pred) results_i = {**results_i, **eval_homography_dlt(data, pred)} else: results_i = {} diff --git a/gluefactory/eval/utils.py b/gluefactory/eval/utils.py index c6e6f006..b89fe792 100644 --- a/gluefactory/eval/utils.py +++ b/gluefactory/eval/utils.py @@ -1,11 +1,12 @@ -import kornia import numpy as np import torch +from kornia.geometry.homography import find_homography_dlt from ..geometry.epipolar import generalized_epi_dist, relative_pose_error from ..geometry.gt_generation import IGNORE_FEATURE from ..geometry.homography import homography_corner_error, sym_homography_error from ..robust_estimators import load_estimator +from ..utils.tensor import index_batch from ..utils.tools import AUCMetric @@ -26,6 +27,16 @@ def get_matches_scores(kpts0, kpts1, matches0, mscores0): return pts0, pts1, scores +def eval_per_batch_item(data: dict, pred: dict, eval_f, *args, **kwargs): + # Batched data + results = [ + eval_f(data_i, pred_i, *args, **kwargs) + for data_i, pred_i in zip(index_batch(data), index_batch(pred)) + ] + # Return a dictionary of lists with the evaluation of each item + return {k: [r[k] for r in results] for k in results[0].keys()} + + def eval_matches_epipolar(data: dict, pred: dict) -> dict: check_keys_recursive(data, ["view0", "view1", "T_0to1"]) check_keys_recursive( @@ -58,23 +69,25 @@ def eval_matches_epipolar(data: dict, pred: dict) -> dict: return results -def eval_matches_homography(data: dict, pred: dict, conf) -> dict: +def eval_matches_homography(data: dict, pred: dict) -> dict: check_keys_recursive(data, ["H_0to1"]) check_keys_recursive( pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"] ) H_gt = data["H_0to1"] + if H_gt.ndim > 2: + return eval_per_batch_item(data, pred, eval_matches_homography) + kp0, kp1 = pred["keypoints0"], pred["keypoints1"] m0, scores0 = pred["matches0"], pred["matching_scores0"] pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0) - err = sym_homography_error(pts0, pts1, H_gt[0]) + err = sym_homography_error(pts0, pts1, H_gt) results = {} results["prec@1px"] = (err < 1).float().mean().nan_to_num().item() results["prec@3px"] = (err < 3).float().mean().nan_to_num().item() results["num_matches"] = pts0.shape[0] results["num_keypoints"] = (kp0.shape[0] + kp1.shape[0]) / 2.0 - return results @@ -84,7 +97,7 @@ def eval_relative_pose_robust(data, pred, conf): pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"] ) - T_gt = data["T_0to1"][0] + T_gt = data["T_0to1"] kp0, kp1 = pred["keypoints0"], pred["keypoints1"] m0, scores0 = pred["matches0"], pred["matching_scores0"] pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0) @@ -107,9 +120,8 @@ def eval_relative_pose_robust(data, pred, conf): else: # R, t, inl = ret M = est["M_0to1"] - R, t = M.numpy() inl = est["inliers"].numpy() - r_error, t_error = relative_pose_error(T_gt, R, t) + t_error, r_error = relative_pose_error(T_gt, M.R, M.t) results["rel_pose_error"] = max(r_error, t_error) results["ransac_inl"] = np.sum(inl) results["ransac_inl%"] = np.mean(inl) @@ -119,6 +131,9 @@ def eval_relative_pose_robust(data, pred, conf): def eval_homography_robust(data, pred, conf): H_gt = data["H_0to1"] + if H_gt.ndim > 2: + return eval_per_batch_item(data, pred, eval_relative_pose_robust, conf) + estimator = load_estimator("homography", conf["estimator"])(conf) data_ = {} @@ -158,24 +173,26 @@ def eval_homography_robust(data, pred, conf): return results -def eval_homography_dlt(data, pred, *args): +def eval_homography_dlt(data, pred): H_gt = data["H_0to1"] H_inf = torch.ones_like(H_gt) * float("inf") kp0, kp1 = pred["keypoints0"], pred["keypoints1"] m0, scores0 = pred["matches0"], pred["matching_scores0"] pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0) + scores = scores.to(pts0) results = {} try: - Hdlt = kornia.geometry.homography.find_homography_dlt( - pts0[None], pts1[None], scores[None].to(pts0) - )[0] + if H_gt.ndim == 2: + pts0, pts1, scores = pts0[None], pts1[None], scores[None] + h_dlt = find_homography_dlt(pts0, pts1, scores) + if H_gt.ndim == 2: + h_dlt = h_dlt[0] except AssertionError: - Hdlt = H_inf + h_dlt = H_inf - error_dlt = homography_corner_error(Hdlt, H_gt, data["view0"]["image_size"]) + error_dlt = homography_corner_error(h_dlt, H_gt, data["view0"]["image_size"]) results["H_error_dlt"] = error_dlt.item() - return results diff --git a/gluefactory/geometry/epipolar.py b/gluefactory/geometry/epipolar.py index 7e1507c0..1f7bb9ce 100644 --- a/gluefactory/geometry/epipolar.py +++ b/gluefactory/geometry/epipolar.py @@ -1,4 +1,3 @@ -import numpy as np import torch from .utils import skew_symmetric, to_homogeneous @@ -124,39 +123,33 @@ def decompose_essential_matrix(E): # pose errors -# TODO: port to torch and batch +# TODO: test for batched data def angle_error_mat(R1, R2): - cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 - cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds - return np.rad2deg(np.abs(np.arccos(cos))) + cos = (torch.trace(torch.einsum("...ij, ...jk -> ...ik", R1.T, R2)) - 1) / 2 + cos = torch.clip(cos, -1.0, 1.0) # numerical errors can make it out of bounds + return torch.rad2deg(torch.abs(torch.arccos(cos))) -def angle_error_vec(v1, v2): - n = np.linalg.norm(v1) * np.linalg.norm(v2) - return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) +def angle_error_vec(v1, v2, eps=1e-10): + n = torch.clip(v1.norm(dim=-1) * v2.norm(dim=-1), min=eps) + v1v2 = (v1 * v2).sum(dim=-1) # dot product in the last dimension + return torch.rad2deg(torch.arccos(torch.clip(v1v2 / n, -1.0, 1.0))) -def compute_pose_error(T_0to1, R, t): - R_gt = T_0to1[:3, :3] - t_gt = T_0to1[:3, 3] - error_t = angle_error_vec(t, t_gt) - error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation - error_R = angle_error_mat(R, R_gt) - return error_t, error_R - +def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0, eps=1e-10): + if isinstance(T_0to1, torch.Tensor): + R_gt, t_gt = T_0to1[:3, :3], T_0to1[:3, 3] + else: + R_gt, t_gt = T_0to1.R, T_0to1.t + R_gt, t_gt = torch.squeeze(R_gt), torch.squeeze(t_gt) -def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): # angle error between 2 vectors - R_gt, t_gt = T_0to1.numpy() - n = np.linalg.norm(t) * np.linalg.norm(t_gt) - t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0))) - t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity - if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging + t_err = angle_error_vec(t, t_gt, eps) + t_err = torch.minimum(t_err, 180 - t_err) # handle E ambiguity + if t_gt.norm() < ignore_gt_t_thr: # pure rotation is challenging t_err = 0 # angle error between 2 rotation matrices - cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 - cos = np.clip(cos, -1.0, 1.0) # handle numercial errors - R_err = np.rad2deg(np.abs(np.arccos(cos))) + r_err = angle_error_mat(R, R_gt) - return t_err, R_err + return t_err, r_err diff --git a/gluefactory/geometry/homography.py b/gluefactory/geometry/homography.py index 3acb9307..f87b9f90 100644 --- a/gluefactory/geometry/homography.py +++ b/gluefactory/geometry/homography.py @@ -164,7 +164,8 @@ def warp_points_torch(points, H, inverse=True): The inverse is used to be coherent with tf.contrib.image.transform Arguments: points: batched list of N points, shape (B, N, 2). - homography: batched or not (shapes (B, 3, 3) and (3, 3) respectively). + H: batched or not (shapes (B, 3, 3) and (3, 3) respectively). + inverse: Whether to multiply the points by H or the inverse of H Returns: a Tensor of shape (B, N, 2) containing the new coordinates of the warps. """ @@ -333,7 +334,7 @@ def sym_homography_error_all(kpts0, kpts1, H): def homography_corner_error(T, T_gt, image_size): - W, H = image_size[:, 0], image_size[:, 1] + W, H = image_size[..., 0], image_size[..., 1] corners0 = torch.Tensor([[0, 0], [W, 0], [W, H], [0, H]]).float().to(T) corners1_gt = from_homogeneous(to_homogeneous(corners0) @ T_gt.transpose(-1, -2)) corners1 = from_homogeneous(to_homogeneous(corners0) @ T.transpose(-1, -2)) diff --git a/gluefactory/geometry/utils.py b/gluefactory/geometry/utils.py index eec330a9..4734e341 100644 --- a/gluefactory/geometry/utils.py +++ b/gluefactory/geometry/utils.py @@ -23,6 +23,7 @@ def from_homogeneous(points, eps=0.0): """Remove the homogeneous dimension of N-dimensional points. Args: points: torch.Tensor or numpy.ndarray with size (..., N+1). + eps: Epsilon value to prevent zero division. Returns: A torch.Tensor or numpy ndarray with size (..., N). """ diff --git a/gluefactory/models/matchers/gluestick.py b/gluefactory/models/matchers/gluestick.py index 0187e0c3..e16a8a52 100644 --- a/gluefactory/models/matchers/gluestick.py +++ b/gluefactory/models/matchers/gluestick.py @@ -119,7 +119,7 @@ def _init(self, conf): "Loading GlueStick model from " f'"{self.url.format(conf.version)}"' ) state_dict = torch.hub.load_state_dict_from_url( - self.url.format(conf.version), file_name=fname + self.url.format(conf.version), file_name=fname, map_location="cpu" ) if "model" in state_dict: diff --git a/gluefactory/models/matchers/lightglue_pretrained.py b/gluefactory/models/matchers/lightglue_pretrained.py index b23976db..985a0ce1 100644 --- a/gluefactory/models/matchers/lightglue_pretrained.py +++ b/gluefactory/models/matchers/lightglue_pretrained.py @@ -17,7 +17,7 @@ class LightGlue(BaseModel): def _init(self, conf): dconf = OmegaConf.to_container(conf) - self.net = LightGlue_(dconf.pop("features"), **dconf).cuda() + self.net = LightGlue_(dconf.pop("features"), **dconf) self.set_initialized() def _forward(self, data): diff --git a/gluefactory/robust_estimators/homography/homography_est.py b/gluefactory/robust_estimators/homography/homography_est.py index 510650c4..780011ee 100644 --- a/gluefactory/robust_estimators/homography/homography_est.py +++ b/gluefactory/robust_estimators/homography/homography_est.py @@ -7,6 +7,7 @@ ransac_point_line_homography, ) +from ...utils.tensor import batch_to_numpy from ..base_estimator import BaseEstimator @@ -50,19 +51,20 @@ def _init(self, conf): pass def _forward(self, data): + feat = data["m_kpts0"] if "m_kpts0" in data else data["m_lines0"] + data = batch_to_numpy(data) m_features = { - "kpts0": data["m_kpts1"].numpy() if "m_kpts1" in data else None, - "kpts1": data["m_kpts0"].numpy() if "m_kpts0" in data else None, - "lines0": data["m_lines1"].numpy() if "m_lines1" in data else None, - "lines1": data["m_lines0"].numpy() if "m_lines0" in data else None, + "kpts0": data["m_kpts1"] if "m_kpts1" in data else None, + "kpts1": data["m_kpts0"] if "m_kpts0" in data else None, + "lines0": data["m_lines1"] if "m_lines1" in data else None, + "lines1": data["m_lines0"] if "m_lines0" in data else None, } - feat = data["m_kpts0"] if "m_kpts0" in data else data["m_lines0"] M = H_estimation_hybrid(**m_features, tol_px=self.conf.ransac_th) success = M is not None if not success: M = torch.eye(3, device=feat.device, dtype=feat.dtype) else: - M = torch.tensor(M).to(feat) + M = torch.from_numpy(M).to(feat) estimation = { "success": success, diff --git a/gluefactory/robust_estimators/homography/poselib.py b/gluefactory/robust_estimators/homography/poselib.py index e99e9493..6aa71496 100644 --- a/gluefactory/robust_estimators/homography/poselib.py +++ b/gluefactory/robust_estimators/homography/poselib.py @@ -16,8 +16,8 @@ def _init(self, conf): def _forward(self, data): pts0, pts1 = data["m_kpts0"], data["m_kpts1"] M, info = poselib.estimate_homography( - pts0.numpy(), - pts1.numpy(), + pts0.detach().cpu().numpy(), + pts1.detach().cpu().numpy(), { "max_reproj_error": self.conf.ransac_th, **OmegaConf.to_container(self.conf.options), diff --git a/gluefactory/utils/tensor.py b/gluefactory/utils/tensor.py index f31bb580..d0a8ca50 100644 --- a/gluefactory/utils/tensor.py +++ b/gluefactory/utils/tensor.py @@ -40,3 +40,9 @@ def rbd(data: dict) -> dict: k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v for k, v in data.items() } + + +def index_batch(tensor_dict): + batch_size = len(next(iter(tensor_dict.values()))) + for i in range(batch_size): + yield map_tensor(tensor_dict, lambda t: t[i]) diff --git a/pyproject.toml b/pyproject.toml index 5185a753..b740a956 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,12 +38,12 @@ urls = {Repository = "https://github.com/cvg/glue-factory"} [project.optional-dependencies] extra = [ "pycolmap", - "poselib @ git+https://github.com/PoseLib/PoseLib.git", - "pytlsd @ git+https://github.com/iago-suarez/pytlsd.git", + "poselib @ git+https://github.com/PoseLib/PoseLib.git@9c8f3ca1baba69e19726cc7caded574873ec1f9e", + "pytlsd @ git+https://github.com/iago-suarez/pytlsd.git@v0.0.5", "deeplsd @ git+https://github.com/cvg/DeepLSD.git", - "homography_est @ git+https://github.com/rpautrat/homography_est.git", + "homography_est @ git+https://github.com/rpautrat/homography_est.git@17b200d528e6aa8ac61a878a29265bf5f9d36c41", ] -dev = ["black", "flake8", "isort"] +dev = ["black", "flake8", "isort", "parameterized"] [tool.setuptools.packages.find] include = ["gluefactory*"] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_eval_utils.py b/tests/test_eval_utils.py new file mode 100644 index 00000000..fead8964 --- /dev/null +++ b/tests/test_eval_utils.py @@ -0,0 +1,88 @@ +import unittest + +import torch + +from gluefactory.eval.utils import eval_matches_homography +from gluefactory.geometry.homography import warp_points_torch + + +class TestEvalUtils(unittest.TestCase): + @staticmethod + def default_pts(): + return torch.tensor( + [ + [10.0, 10.0], + [10.0, 20.0], + [20.0, 20.0], + [20.0, 10.0], + ] + ) + + @staticmethod + def default_pred(kps0, kps1): + return { + "keypoints0": kps0, + "keypoints1": kps1, + "matches0": torch.arange(len(kps0)), + "matching_scores0": torch.ones(len(kps1)), + } + + def test_eval_matches_homography_trivial(self): + data = {"H_0to1": torch.eye(3)} + kps = self.default_pts() + pred = self.default_pred(kps, kps) + + results = eval_matches_homography(data, pred) + + self.assertEqual(results["prec@1px"], 1) + self.assertEqual(results["prec@3px"], 1) + self.assertEqual(results["num_matches"], 4) + self.assertEqual(results["num_keypoints"], 4) + + def test_eval_matches_homography_real(self): + data = {"H_0to1": torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]])} + kps0 = self.default_pts() + kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False) + pred = self.default_pred(kps0, kps1) + + results = eval_matches_homography(data, pred) + + self.assertEqual(results["prec@1px"], 1) + self.assertEqual(results["prec@3px"], 1) + + def test_eval_matches_homography_real_outliers(self): + data = {"H_0to1": torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]])} + kps0 = self.default_pts() + kps0 = torch.cat([kps0, torch.tensor([[5.0, 5.0]])]) + kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False) + # Move one keypoint 1.5 pixels away in x and y + kps1[-1] += 1.5 + pred = self.default_pred(kps0, kps1) + + results = eval_matches_homography(data, pred) + self.assertAlmostEqual(results["prec@1px"], 0.8) + self.assertAlmostEqual(results["prec@3px"], 1.0) + + def test_eval_matches_homography_batched(self): + H0 = torch.tensor([[1.5, 0.2, 21], [-0.3, 1.6, 33], [0, 0, 1.0]]) + H1 = torch.tensor([[0.7, 0.1, -5], [-0.1, 0.65, 13], [0, 0, 1.0]]) + data = {"H_0to1": torch.stack([H0, H1])} + kps0 = torch.stack([self.default_pts(), self.default_pts().flip(0)]) + kps1 = warp_points_torch(kps0, data["H_0to1"], inverse=False) + # In the first element of the batch there is one outlier + kps1[0, -1] += 5 + matches0 = torch.stack([torch.arange(4), torch.arange(4)]) + # In the second element of the batch there is only 2 matches + matches0[1, :2] = -1 + pred = { + "keypoints0": kps0, + "keypoints1": kps1, + "matches0": matches0, + "matching_scores0": torch.ones_like(matches0), + } + + results = eval_matches_homography(data, pred) + self.assertAlmostEqual(results["prec@1px"][0], 0.75) + self.assertAlmostEqual(results["prec@1px"][1], 1.0) + self.assertAlmostEqual(results["num_matches"][0], 4) + self.assertAlmostEqual(results["num_matches"][1], 2) diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 00000000..e459ada5 --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,132 @@ +import unittest +from collections import namedtuple +from os.path import splitext + +import cv2 +import matplotlib.pyplot as plt +import torch.cuda +from kornia import image_to_tensor +from omegaconf import OmegaConf +from parameterized import parameterized +from torch import Tensor + +from gluefactory import logger +from gluefactory.eval.utils import ( + eval_homography_dlt, + eval_homography_robust, + eval_matches_homography, +) +from gluefactory.models.two_view_pipeline import TwoViewPipeline +from gluefactory.settings import root +from gluefactory.utils.image import ImagePreprocessor +from gluefactory.utils.tensor import map_tensor +from gluefactory.utils.tools import set_seed +from gluefactory.visualization.viz2d import ( + plot_color_line_matches, + plot_images, + plot_matches, +) + + +def create_input_data(cv_img0, cv_img1, device): + img0 = image_to_tensor(cv_img0).float() / 255 + img1 = image_to_tensor(cv_img1).float() / 255 + ip = ImagePreprocessor({}) + data = {"view0": ip(img0), "view1": ip(img1)} + data = map_tensor( + data, + lambda t: t[None].to(device) + if isinstance(t, Tensor) + else torch.from_numpy(t)[None].to(device), + ) + return data + + +ExpectedResults = namedtuple("ExpectedResults", ("num_matches", "prec3px", "h_error")) + + +class TestIntegration(unittest.TestCase): + methods_to_test = [ + ("superpoint+NN.yaml", "poselib", ExpectedResults(1300, 0.8, 1.0)), + ("superpoint-open+NN.yaml", "poselib", ExpectedResults(1300, 0.8, 1.0)), + ( + "superpoint+lsd+gluestick.yaml", + "homography_est", + ExpectedResults(1300, 0.8, 1.0), + ), + ( + "superpoint+lightglue-official.yaml", + "poselib", + ExpectedResults(1300, 0.8, 1.0), + ), + ] + + visualize = False + + @parameterized.expand(methods_to_test) + @torch.no_grad() + def test_real_homography(self, conf_file, estimator, exp_results): + set_seed(0) + model_path = root / "gluefactory" / "configs" / conf_file + img_path0 = root / "assets" / "boat1.png" + img_path1 = root / "assets" / "boat2.png" + h_gt = torch.tensor( + [ + [0.85799, 0.21669, 9.4839], + [-0.21177, 0.85855, 130.48], + [1.5015e-06, 9.2033e-07, 1], + ] + ) + + device = "cuda" if torch.cuda.is_available() else "cpu" + gs = TwoViewPipeline(OmegaConf.load(model_path).model).to(device).eval() + + cv_img0, cv_img1 = cv2.imread(str(img_path0)), cv2.imread(str(img_path1)) + data = create_input_data(cv_img0, cv_img1, device) + pred = gs(data) + pred = map_tensor( + pred, lambda t: torch.squeeze(t, dim=0) if isinstance(t, Tensor) else t + ) + data["H_0to1"] = h_gt.to(device) + data["H_1to0"] = torch.linalg.inv(h_gt).to(device) + + results = eval_matches_homography(data, pred) + results = {**results, **eval_homography_dlt(data, pred)} + + results = { + **results, + **eval_homography_robust( + data, + pred, + {"estimator": estimator}, + ), + } + + logger.info(results) + self.assertGreater(results["num_matches"], exp_results.num_matches) + self.assertGreater(results["prec@3px"], exp_results.prec3px) + self.assertLess(results["H_error_ransac"], exp_results.h_error) + + if self.visualize: + pred = map_tensor( + pred, lambda t: t.cpu().numpy() if isinstance(t, Tensor) else t + ) + kp0, kp1 = pred["keypoints0"], pred["keypoints1"] + m0 = pred["matches0"] + valid0 = m0 != -1 + kpm0, kpm1 = kp0[valid0], kp1[m0[valid0]] + + plot_images([cv_img0, cv_img1]) + plot_matches(kpm0, kpm1, a=0.0) + plt.savefig(f"{splitext(conf_file)[0]}_point_matches.svg") + + if "lines0" in pred and "lines1" in pred: + lines0, lines1 = pred["lines0"], pred["lines1"] + lm0 = pred["line_matches0"] + lvalid0 = lm0 != -1 + linem0, linem1 = lines0[lvalid0], lines1[lm0[lvalid0]] + + plot_images([cv_img0, cv_img1]) + plot_color_line_matches([linem0, linem1]) + plt.savefig(f"{splitext(conf_file)[0]}_line_matches.svg") + plt.show()