diff --git a/gluefactory/datasets/eth3d.py b/gluefactory/datasets/eth3d.py
index 44fd73f8..953d775e 100644
--- a/gluefactory/datasets/eth3d.py
+++ b/gluefactory/datasets/eth3d.py
@@ -1,6 +1,7 @@
 """
 ETH3D multi-view benchmark, used for line matching evaluation.
 """
+
 import logging
 import os
 import shutil
diff --git a/gluefactory/datasets/hpatches.py b/gluefactory/datasets/hpatches.py
index 80846603..cf4c7993 100644
--- a/gluefactory/datasets/hpatches.py
+++ b/gluefactory/datasets/hpatches.py
@@ -1,6 +1,7 @@
 """
 Simply load images from a folder or nested folders (does not have any split).
 """
+
 import argparse
 import logging
 import tarfile
diff --git a/gluefactory/geometry/gt_generation.py b/gluefactory/geometry/gt_generation.py
index 21390cd7..b80a7778 100644
--- a/gluefactory/geometry/gt_generation.py
+++ b/gluefactory/geometry/gt_generation.py
@@ -375,9 +375,9 @@ def gt_line_matches_from_pose_depth(
     all_in_batch = (
         torch.arange(b_size)[:, None].repeat(1, assignation.shape[-1]).flatten()
     )
-    positive[
-        all_in_batch, assignation[:, 0].flatten(), assignation[:, 1].flatten()
-    ] = True
+    positive[all_in_batch, assignation[:, 0].flatten(), assignation[:, 1].flatten()] = (
+        True
+    )
 
     m0 = assignation.new_full((b_size, n_lines0), unmatched, dtype=torch.long)
     m0.scatter_(-1, assignation[:, 0], assignation[:, 1])
diff --git a/gluefactory/models/cache_loader.py b/gluefactory/models/cache_loader.py
index b345a997..837421ba 100644
--- a/gluefactory/models/cache_loader.py
+++ b/gluefactory/models/cache_loader.py
@@ -47,9 +47,11 @@ def pad_line_features(pred, seq_l: int = None):
 
 def recursive_load(grp, pkeys):
     return {
-        k: torch.from_numpy(grp[k].__array__())
-        if isinstance(grp[k], h5py.Dataset)
-        else recursive_load(grp[k], list(grp.keys()))
+        k: (
+            torch.from_numpy(grp[k].__array__())
+            if isinstance(grp[k], h5py.Dataset)
+            else recursive_load(grp[k], list(grp.keys()))
+        )
         for k in pkeys
     }
 
@@ -108,9 +110,12 @@ def _forward(self, data):
             pred = recursive_load(grp, pkeys)
             if self.numeric_dtype is not None:
                 pred = {
-                    k: v
-                    if not isinstance(v, torch.Tensor) or not torch.is_floating_point(v)
-                    else v.to(dtype=self.numeric_dtype)
+                    k: (
+                        v
+                        if not isinstance(v, torch.Tensor)
+                        or not torch.is_floating_point(v)
+                        else v.to(dtype=self.numeric_dtype)
+                    )
                     for k, v in pred.items()
                 }
             pred = batch_to_device(pred, device)
diff --git a/gluefactory/models/extractors/aliked.py b/gluefactory/models/extractors/aliked.py
index 80cd348a..254a434e 100644
--- a/gluefactory/models/extractors/aliked.py
+++ b/gluefactory/models/extractors/aliked.py
@@ -717,9 +717,11 @@ def _init(self, conf):
             radius=conf.nms_radius,
             top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
             scores_th=conf.detection_threshold,
-            n_limit=conf.max_num_keypoints
-            if conf.max_num_keypoints > 0
-            else self.n_limit_max,
+            n_limit=(
+                conf.max_num_keypoints
+                if conf.max_num_keypoints > 0
+                else self.n_limit_max
+            ),
         )
 
         # load pretrained
diff --git a/gluefactory/models/extractors/superpoint_open.py b/gluefactory/models/extractors/superpoint_open.py
index 1f960407..434e0a1d 100644
--- a/gluefactory/models/extractors/superpoint_open.py
+++ b/gluefactory/models/extractors/superpoint_open.py
@@ -5,6 +5,7 @@
 The implementation of this model and its trained weights are made
 available under the MIT license.
 """
+
 from collections import OrderedDict
 from types import SimpleNamespace
 
diff --git a/gluefactory/models/lines/wireframe.py b/gluefactory/models/lines/wireframe.py
index ac0d0b5a..8f541c6a 100644
--- a/gluefactory/models/lines/wireframe.py
+++ b/gluefactory/models/lines/wireframe.py
@@ -256,9 +256,9 @@ def _forward(self, data):
             associativity = torch.eye(
                 len(all_points[-1]), dtype=torch.bool, device=device
             )
-            associativity[
-                : n_true_junctions[bs], : n_true_junctions[bs]
-            ] = line_association[bs][: n_true_junctions[bs], : n_true_junctions[bs]]
+            associativity[: n_true_junctions[bs], : n_true_junctions[bs]] = (
+                line_association[bs][: n_true_junctions[bs], : n_true_junctions[bs]]
+            )
             pl_associativity.append(associativity)
 
         all_points = torch.stack(all_points, dim=0)
diff --git a/tests/test_integration.py b/tests/test_integration.py
index e459ada5..3592cff1 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -35,9 +35,11 @@ def create_input_data(cv_img0, cv_img1, device):
     data = {"view0": ip(img0), "view1": ip(img1)}
     data = map_tensor(
         data,
-        lambda t: t[None].to(device)
-        if isinstance(t, Tensor)
-        else torch.from_numpy(t)[None].to(device),
+        lambda t: (
+            t[None].to(device)
+            if isinstance(t, Tensor)
+            else torch.from_numpy(t)[None].to(device)
+        ),
     )
     return data