Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding basic CI and fixing minor problems #17

Merged
merged 4 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
name: Python Tests
on:
push:
branches:
- main
pull_request:
types: [ assigned, opened, synchronize, reopened ]
jobs:
build:
name: Run Python Tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'pip'
- name: Install dependencies
run: |
sudo apt-get remove libunwind-14-dev || true
sudo apt-get install -y libceres-dev libeigen3-dev
python -m pip install --upgrade pip
python -m pip install pytest pytest-cov
python -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
python -m pip install -e .[dev]
python -m pip install -e .[extra]
- name: Test with pytest
run: |
set -o pipefail
pytest --junitxml=pytest.xml --cov=gluefactory tests/
Binary file added assets/boat1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/boat2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 5 additions & 1 deletion gluefactory/eval/hpatches.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@

import matplotlib.pyplot as plt
import numpy as np
import torch
from omegaconf import OmegaConf
from tqdm import tqdm

from ..datasets import get_dataset
from ..models.cache_loader import CacheLoader
from ..settings import EVAL_PATH
from ..utils.export_predictions import export_predictions
from ..utils.tensor import map_tensor
from ..utils.tools import AUCMetric
from ..visualization.viz2d import plot_cumulative
from .eval_pipeline import EvalPipeline
Expand Down Expand Up @@ -105,9 +107,11 @@ def run_eval(self, loader, pred_file):
cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
for i, data in enumerate(tqdm(loader)):
pred = cache_loader(data)
# Remove batch dimension
data = map_tensor(data, lambda t: torch.squeeze(t, dim=0))
# add custom evaluations here
if "keypoints0" in pred:
results_i = eval_matches_homography(data, pred, {})
results_i = eval_matches_homography(data, pred)
results_i = {**results_i, **eval_homography_dlt(data, pred)}
else:
results_i = {}
Expand Down
45 changes: 31 additions & 14 deletions gluefactory/eval/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import kornia
import numpy as np
import torch
from kornia.geometry.homography import find_homography_dlt

from ..geometry.epipolar import generalized_epi_dist, relative_pose_error
from ..geometry.gt_generation import IGNORE_FEATURE
from ..geometry.homography import homography_corner_error, sym_homography_error
from ..robust_estimators import load_estimator
from ..utils.tensor import index_batch
from ..utils.tools import AUCMetric


Expand All @@ -26,6 +27,16 @@ def get_matches_scores(kpts0, kpts1, matches0, mscores0):
return pts0, pts1, scores


def eval_per_batch_item(data: dict, pred: dict, eval_f, *args, **kwargs):
# Batched data
results = [
eval_f(data_i, pred_i, *args, **kwargs)
for data_i, pred_i in zip(index_batch(data), index_batch(pred))
]
# Return a dictionary of lists with the evaluation of each item
return {k: [r[k] for r in results] for k in results[0].keys()}


def eval_matches_epipolar(data: dict, pred: dict) -> dict:
check_keys_recursive(data, ["view0", "view1", "T_0to1"])
check_keys_recursive(
Expand Down Expand Up @@ -58,23 +69,25 @@ def eval_matches_epipolar(data: dict, pred: dict) -> dict:
return results


def eval_matches_homography(data: dict, pred: dict, conf) -> dict:
def eval_matches_homography(data: dict, pred: dict) -> dict:
check_keys_recursive(data, ["H_0to1"])
check_keys_recursive(
pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
)

H_gt = data["H_0to1"]
if H_gt.ndim > 2:
return eval_per_batch_item(data, pred, eval_matches_homography)

kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
m0, scores0 = pred["matches0"], pred["matching_scores0"]
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
err = sym_homography_error(pts0, pts1, H_gt[0])
err = sym_homography_error(pts0, pts1, H_gt)
results = {}
results["prec@1px"] = (err < 1).float().mean().nan_to_num().item()
results["prec@3px"] = (err < 3).float().mean().nan_to_num().item()
results["num_matches"] = pts0.shape[0]
results["num_keypoints"] = (kp0.shape[0] + kp1.shape[0]) / 2.0

return results


Expand All @@ -84,7 +97,7 @@ def eval_relative_pose_robust(data, pred, conf):
pred, ["keypoints0", "keypoints1", "matches0", "matching_scores0"]
)

T_gt = data["T_0to1"][0]
T_gt = data["T_0to1"]
kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
m0, scores0 = pred["matches0"], pred["matching_scores0"]
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
Expand All @@ -107,9 +120,8 @@ def eval_relative_pose_robust(data, pred, conf):
else:
# R, t, inl = ret
M = est["M_0to1"]
R, t = M.numpy()
inl = est["inliers"].numpy()
r_error, t_error = relative_pose_error(T_gt, R, t)
t_error, r_error = relative_pose_error(T_gt, M.R, M.t)
results["rel_pose_error"] = max(r_error, t_error)
results["ransac_inl"] = np.sum(inl)
results["ransac_inl%"] = np.mean(inl)
Expand All @@ -119,6 +131,9 @@ def eval_relative_pose_robust(data, pred, conf):

def eval_homography_robust(data, pred, conf):
H_gt = data["H_0to1"]
if H_gt.ndim > 2:
return eval_per_batch_item(data, pred, eval_relative_pose_robust, conf)

estimator = load_estimator("homography", conf["estimator"])(conf)

data_ = {}
Expand Down Expand Up @@ -158,24 +173,26 @@ def eval_homography_robust(data, pred, conf):
return results


def eval_homography_dlt(data, pred, *args):
def eval_homography_dlt(data, pred):
H_gt = data["H_0to1"]
H_inf = torch.ones_like(H_gt) * float("inf")

kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
m0, scores0 = pred["matches0"], pred["matching_scores0"]
pts0, pts1, scores = get_matches_scores(kp0, kp1, m0, scores0)
scores = scores.to(pts0)
results = {}
try:
Hdlt = kornia.geometry.homography.find_homography_dlt(
pts0[None], pts1[None], scores[None].to(pts0)
)[0]
if H_gt.ndim == 2:
pts0, pts1, scores = pts0[None], pts1[None], scores[None]
h_dlt = find_homography_dlt(pts0, pts1, scores)
if H_gt.ndim == 2:
h_dlt = h_dlt[0]
except AssertionError:
Hdlt = H_inf
h_dlt = H_inf

error_dlt = homography_corner_error(Hdlt, H_gt, data["view0"]["image_size"])
error_dlt = homography_corner_error(h_dlt, H_gt, data["view0"]["image_size"])
results["H_error_dlt"] = error_dlt.item()

return results


Expand Down
37 changes: 18 additions & 19 deletions gluefactory/geometry/epipolar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import numpy as np
import torch

from .utils import skew_symmetric, to_homogeneous
Expand Down Expand Up @@ -124,39 +123,39 @@ def decompose_essential_matrix(E):


# pose errors
# TODO: port to torch and batch
# TODO: test for batched data
def angle_error_mat(R1, R2):
cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds
return np.rad2deg(np.abs(np.arccos(cos)))
cos = (torch.trace(torch.einsum("...ij, ...jk -> ...ik", R1.T, R2)) - 1) / 2
cos = torch.clip(cos, -1.0, 1.0) # numerical errors can make it out of bounds
return torch.rad2deg(torch.abs(torch.arccos(cos)))


def angle_error_vec(v1, v2):
n = np.linalg.norm(v1) * np.linalg.norm(v2)
return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
def angle_error_vec(v1, v2, eps=1e-10):
n = torch.clip(v1.norm(dim=-1) * v2.norm(dim=-1), min=eps)
v1v2 = (v1 * v2).sum(dim=-1) # dot product in the last dimension
return torch.rad2deg(torch.arccos(torch.clip(v1v2 / n, -1.0, 1.0)))


def compute_pose_error(T_0to1, R, t):
R_gt = T_0to1[:3, :3]
t_gt = T_0to1[:3, 3]
error_t = angle_error_vec(t, t_gt)
error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
error_t = torch.minimum(error_t, 180 - error_t) # ambiguity of E estimation
error_R = angle_error_mat(R, R_gt)
return error_t, error_R


def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0, eps=1e-10):
sarlinpe marked this conversation as resolved.
Show resolved Hide resolved
R_gt, t_gt = T_0to1.R, T_0to1.t
R_gt, t_gt = torch.squeeze(R_gt), torch.squeeze(t_gt)

# angle error between 2 vectors
R_gt, t_gt = T_0to1.numpy()
n = np.linalg.norm(t) * np.linalg.norm(t_gt)
t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity
if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging
t_err = angle_error_vec(t, t_gt, eps)
t_err = torch.minimum(t_err, 180 - t_err) # handle E ambiguity
if t_gt.norm() < ignore_gt_t_thr: # pure rotation is challenging
t_err = 0

# angle error between 2 rotation matrices
cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
cos = np.clip(cos, -1.0, 1.0) # handle numercial errors
R_err = np.rad2deg(np.abs(np.arccos(cos)))
r_err = angle_error_mat(R, R_gt)

return t_err, R_err
return t_err, r_err
5 changes: 3 additions & 2 deletions gluefactory/geometry/homography.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def warp_points_torch(points, H, inverse=True):
The inverse is used to be coherent with tf.contrib.image.transform
Arguments:
points: batched list of N points, shape (B, N, 2).
homography: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
H: batched or not (shapes (B, 3, 3) and (3, 3) respectively).
inverse: Whether to multiply the points by H or the inverse of H
Returns: a Tensor of shape (B, N, 2) containing the new coordinates of the warps.
"""

Expand Down Expand Up @@ -333,7 +334,7 @@ def sym_homography_error_all(kpts0, kpts1, H):


def homography_corner_error(T, T_gt, image_size):
W, H = image_size[:, 0], image_size[:, 1]
W, H = image_size[..., 0], image_size[..., 1]
corners0 = torch.Tensor([[0, 0], [W, 0], [W, H], [0, H]]).float().to(T)
corners1_gt = from_homogeneous(to_homogeneous(corners0) @ T_gt.transpose(-1, -2))
corners1 = from_homogeneous(to_homogeneous(corners0) @ T.transpose(-1, -2))
Expand Down
1 change: 1 addition & 0 deletions gluefactory/geometry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def from_homogeneous(points, eps=0.0):
"""Remove the homogeneous dimension of N-dimensional points.
Args:
points: torch.Tensor or numpy.ndarray with size (..., N+1).
eps: Epsilon value to prevent zero division.
Returns:
A torch.Tensor or numpy ndarray with size (..., N).
"""
Expand Down
2 changes: 1 addition & 1 deletion gluefactory/models/matchers/gluestick.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _init(self, conf):
"Loading GlueStick model from " f'"{self.url.format(conf.version)}"'
)
state_dict = torch.hub.load_state_dict_from_url(
self.url.format(conf.version), file_name=fname
self.url.format(conf.version), file_name=fname, map_location="cpu"
)

if "model" in state_dict:
Expand Down
2 changes: 1 addition & 1 deletion gluefactory/models/matchers/lightglue_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LightGlue(BaseModel):

def _init(self, conf):
dconf = OmegaConf.to_container(conf)
self.net = LightGlue_(dconf.pop("features"), **dconf).cuda()
self.net = LightGlue_(dconf.pop("features"), **dconf)
self.set_initialized()

def _forward(self, data):
Expand Down
14 changes: 8 additions & 6 deletions gluefactory/robust_estimators/homography/homography_est.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ransac_point_line_homography,
)

from ...utils.tensor import batch_to_numpy
from ..base_estimator import BaseEstimator


Expand Down Expand Up @@ -50,19 +51,20 @@ def _init(self, conf):
pass

def _forward(self, data):
feat = data["m_kpts0"] if "m_kpts0" in data else data["m_lines0"]
data = batch_to_numpy(data)
m_features = {
"kpts0": data["m_kpts1"].numpy() if "m_kpts1" in data else None,
"kpts1": data["m_kpts0"].numpy() if "m_kpts0" in data else None,
"lines0": data["m_lines1"].numpy() if "m_lines1" in data else None,
"lines1": data["m_lines0"].numpy() if "m_lines0" in data else None,
"kpts0": data["m_kpts1"] if "m_kpts1" in data else None,
"kpts1": data["m_kpts0"] if "m_kpts0" in data else None,
"lines0": data["m_lines1"] if "m_lines1" in data else None,
"lines1": data["m_lines0"] if "m_lines0" in data else None,
}
feat = data["m_kpts0"] if "m_kpts0" in data else data["m_lines0"]
M = H_estimation_hybrid(**m_features, tol_px=self.conf.ransac_th)
success = M is not None
if not success:
M = torch.eye(3, device=feat.device, dtype=feat.dtype)
else:
M = torch.tensor(M).to(feat)
M = torch.from_numpy(M).to(feat)

estimation = {
"success": success,
Expand Down
4 changes: 2 additions & 2 deletions gluefactory/robust_estimators/homography/poselib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def _init(self, conf):
def _forward(self, data):
pts0, pts1 = data["m_kpts0"], data["m_kpts1"]
M, info = poselib.estimate_homography(
pts0.numpy(),
pts1.numpy(),
pts0.detach().cpu().numpy(),
pts1.detach().cpu().numpy(),
{
"max_reproj_error": self.conf.ransac_th,
**OmegaConf.to_container(self.conf.options),
Expand Down
6 changes: 6 additions & 0 deletions gluefactory/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,9 @@ def rbd(data: dict) -> dict:
k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
for k, v in data.items()
}


def index_batch(tensor_dict):
batch_size = len(next(iter(tensor_dict.values())))
for i in range(batch_size):
yield map_tensor(tensor_dict, lambda t: t[i])
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ urls = {Repository = "https://github.com/cvg/glue-factory"}
[project.optional-dependencies]
extra = [
"pycolmap",
"poselib @ git+https://github.com/PoseLib/PoseLib.git",
"pytlsd @ git+https://github.com/iago-suarez/pytlsd.git",
"poselib @ git+https://github.com/PoseLib/PoseLib.git@9c8f3ca1baba69e19726cc7caded574873ec1f9e",
"pytlsd @ git+https://github.com/iago-suarez/pytlsd.git@v0.0.5",
"deeplsd @ git+https://github.com/cvg/DeepLSD.git",
"homography_est @ git+https://github.com/rpautrat/homography_est.git",
"homography_est @ git+https://github.com/rpautrat/homography_est.git@17b200d528e6aa8ac61a878a29265bf5f9d36c41",
]
dev = ["black", "flake8", "isort"]
dev = ["black", "flake8", "isort", "parameterized"]

[tool.setuptools.packages.find]
include = ["gluefactory*"]
Expand Down
Empty file added tests/__init__.py
Empty file.
Loading