Merge pull request #136 from cvg/bindings
hloc v1.3
skydes authored Jan 4, 2022
2 parents 9bad6b4 + 845f1ac commit d416130
Showing 49 changed files with 9,550 additions and 629 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -3,3 +3,6 @@ __pycache__
*.egg-info
.ipynb_checkpoints
outputs/
third_party/netvlad
datasets/*
!datasets/sacre_coeur/
47 changes: 34 additions & 13 deletions README.md
@@ -16,20 +16,24 @@ With `hloc`, you can:

##

## Quick start ➡️ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1MrVs9b8aQYODtOGkoaGNF9Nji3sbCNMQ)

Build 3D maps with Structure-from-Motion and localize any Internet image right from your browser! **You can now run `hloc` and COLMAP in Google Colab with GPU for free.** The notebook [`demo.ipynb`](https://colab.research.google.com/drive/1MrVs9b8aQYODtOGkoaGNF9Nji3sbCNMQ) shows how to run SfM and localization in just a few steps. Try it with your own data and let us know!

## Installation

`hloc` requires Python >=3.6, PyTorch >=1.1, and [COLMAP](https://colmap.github.io/index.html). Installing the package locally pulls the other dependencies:
```
`hloc` requires Python >=3.7 and PyTorch >=1.1. Installing the package locally pulls the other dependencies:

```bash
git clone --recursive https://github.com/cvg/Hierarchical-Localization/
cd Hierarchical-Localization/
python -m pip install -e .
```

All dependencies are listed in `requirements.txt`.
This codebase includes external local features as git submodules – don't forget to pull submodules with `git submodule update --init --recursive`. Your local features are based on TensorFlow? No problem! See [below](#using-your-own-local-features-or-matcher) for the steps.
All dependencies are listed in `requirements.txt`. **Starting with `hloc-v1.3`, installing COLMAP is not required anymore.** This repository includes external local features as git submodules – don't forget to pull submodules with `git submodule update --init --recursive`.
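
A quick way to validate the setup is an import check (a minimal sketch, mirroring the version guard that this release adds to `hloc/__init__.py`):
```python
# Minimal sanity check: hloc and pycolmap should both import,
# and pycolmap must satisfy the new minimum version.
from packaging import version
import hloc
import pycolmap

print(hloc.__version__)  # expected: 1.3
assert version.parse(pycolmap.__version__) >= version.parse('0.1.0')
```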

We also provide a Docker image that includes COLMAP and other dependencies:
```
We also provide a Docker image:
```bash
docker build -t hloc:latest .
docker run -it --rm -p 8888:8888 hloc:latest # for GPU support, add `--runtime=nvidia`
jupyter notebook --ip 0.0.0.0 --port 8888 --no-browser --allow-root
@@ -58,7 +62,10 @@ Structure of the toolbox:
- `hloc/matchers/` : interfaces for feature matchers
- `hloc/pipelines/` : entire pipelines for multiple datasets

`hloc` can be imported as an external package with `import hloc` or from the command line with `python -m hloc.script`.
`hloc` can be imported as an external package with `import hloc` or called from the command line with:
```bash
python -m hloc.name_of_script --arg1 --arg2
```

## Tasks

@@ -87,7 +94,7 @@ We show in [`pipeline_SfM.ipynb`](https://nbviewer.jupyter.org/github/cvg/Hierar
## Results

- Supported local feature extractors: [SuperPoint](https://arxiv.org/abs/1712.07629), [D2-Net](https://arxiv.org/abs/1905.03561), [SIFT](https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf), and [R2D2](https://arxiv.org/abs/1906.06195).
- Supported feature matchers: [SuperGlue](https://arxiv.org/abs/1911.11763) and nearest neighbor search with ratio test, distance test, mutual check.
- Supported feature matchers: [SuperGlue](https://arxiv.org/abs/1911.11763) and nearest neighbor search with ratio test, distance test, and/or mutual check.
- Supported image retrieval: [NetVLAD](https://arxiv.org/abs/1511.07247) and [AP-GeM/DIR](https://github.com/naver/deep-image-retrieval).

Using NetVLAD for retrieval, we obtain the following best results:
@@ -109,7 +116,10 @@ Check out [visuallocalization.net/benchmark](https://www.visuallocalization.net/

## Supported datasets

We provide in [`hloc/pipelines/`](./hloc/pipelines) scripts to run the reconstruction and the localization on the following datasets: Aachen Day-Night (v1.0 and v1.1), InLoc, Extended CMU Seasons, RobotCar Seasons, 4Seasons, Cambridge Landmarks, and 7-Scenes.
We provide in [`hloc/pipelines/`](./hloc/pipelines) scripts to run the reconstruction and the localization on the following datasets: Aachen Day-Night (v1.0 and v1.1), InLoc, Extended CMU Seasons, RobotCar Seasons, 4Seasons, Cambridge Landmarks, and 7-Scenes. For example, after downloading the dataset [with the instructions given here](./hloc/pipelines/Aachen#installation), we can run the Aachen Day-Night pipeline with SuperPoint+SuperGlue using the command:
```bash
python -m hloc.pipelines.Aachen.pipeline [--outputs ./outputs/aachen]
```

## BibTex Citation

Expand Down Expand Up @@ -178,9 +188,21 @@ In a match file, each key corresponds to the string `path0.replace('/', '-')+'_'
## Versions

<details>
<summary>master (development)</summary>

Multiple bug fixes and minor improvements.
<summary>v1.3 (January 2022)</summary>

- Demo notebook in Google Colab
- Use the new pycolmap Reconstruction objects and pipeline API
- Do not require an installation of COLMAP anymore - pycolmap is enough
- Faster model reading and writing
- Fine-grained control over camera sharing via the `camera_mode` parameter (see the sketch below)
- Localization with unknown or inaccurate focal length
- Modular localization API with control over all estimator parameters
- 3D visualizations of camera frustums and points with plotly
- Package-specific logging in the hloc namespace
- Store the extracted features by default as fp16 instead of fp32
- Optionally fix a long-standing bug in SuperPoint descriptor sampling
- Add script to compute exhaustive pairs for reconstruction or localization
- Require pycolmap>=0.1.0 and Python>=3.7
</details>
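
Several of these items come together in the reconstruction call. A minimal sketch (the parameter names follow this release's `hloc.reconstruction` module; the paths are hypothetical):
```python
from pathlib import Path
import pycolmap
from hloc import reconstruction

outputs = Path('outputs/demo')  # hypothetical output directory
model = reconstruction.main(
    sfm_dir=outputs / 'sfm',
    image_dir=Path('datasets/sacre_coeur/mapping'),
    pairs=outputs / 'pairs.txt',
    features=outputs / 'features.h5',
    matches=outputs / 'matches.h5',
    camera_mode=pycolmap.CameraMode.SINGLE,  # share one camera across all images
)
print(model.summary())  # main() now returns a pycolmap.Reconstruction
```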

<details>
@@ -211,7 +233,6 @@ Initial public version.

External contributions are very much welcome. Please follow the [PEP8 style guidelines](https://www.python.org/dev/peps/pep-0008/) using a linter like flake8. This is a non-exhaustive list of features that might be valuable additions:

- [ ] handle unknown query intrinsics (extraction from EXIF + refinement in PnP)
- [ ] support for GPS (extraction from EXIF + guided retrieval)
- [ ] covisibility clustering for InLoc
- [ ] visualization of the raw predictions (features and matches)
4 changes: 0 additions & 4 deletions datasets/.gitignore
@@ -1,4 +0,0 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore
3 changes: 3 additions & 0 deletions datasets/sacre_coeur/README.md
@@ -0,0 +1,3 @@
# Sacre Coeur demo

We provide here a subset of images depicting the Sacre Coeur. These images were obtained from the [Image Matching Challenge 2021](https://www.cs.ubc.ca/research/image-matching-challenge/2021/data/) and originally come from the [Yahoo Flickr Creative Commons 100M (YFCC) dataset](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/).
(10 binary files are not displayed.)
8,614 changes: 8,614 additions & 0 deletions demo.ipynb

Large diffs are not rendered by default.

32 changes: 26 additions & 6 deletions hloc/__init__.py
@@ -1,9 +1,29 @@
import logging
import sys
from packaging import version

__version__ = '1.2'
__version__ = '1.3'

logging.basicConfig(stream=sys.stdout,
                    format='[%(asctime)s %(levelname)s] %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
formatter = logging.Formatter(
    fmt='[%(asctime)s %(name)s %(levelname)s] %(message)s',
    datefmt='%Y/%m/%d %H:%M:%S')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)

logger = logging.getLogger("hloc")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.propagate = False

try:
    import pycolmap
except ImportError:
    logger.warning('pycolmap is not installed, some features may not work.')
else:
    minimal_version = version.parse('0.1.0')
    found_version = version.parse(getattr(pycolmap, '__version__'))
    if found_version < minimal_version:
        logger.warning(
            'hloc now requires pycolmap>=%s but found pycolmap==%s, '
            'please upgrade with `pip install --upgrade pycolmap`',
            minimal_version, found_version)
20 changes: 10 additions & 10 deletions hloc/colmap_from_nvm.py
@@ -4,8 +4,8 @@
from collections import defaultdict
import numpy as np
from pathlib import Path
import logging

from . import logger
from .utils.read_write_model import Camera, Image, Point3D, CAMERA_MODEL_NAMES
from .utils.read_write_model import write_model

@@ -19,7 +19,7 @@ def recover_database_images_and_ids(database_path):
        images[name] = image_id
        cameras[name] = camera_id
    db.close()
    logging.info(
    logger.info(
        f'Found {len(images)} images and {len(cameras)} cameras in database.')
    return images, cameras

@@ -45,7 +45,7 @@ def read_nvm_model(
    with open(intrinsics_path, 'r') as f:
        raw_intrinsics = f.readlines()

    logging.info(f'Reading {len(raw_intrinsics)} cameras...')
    logger.info(f'Reading {len(raw_intrinsics)} cameras...')
    cameras = {}
    for intrinsics in raw_intrinsics:
        intrinsics = intrinsics.strip('\n').split(' ')
@@ -66,7 +66,7 @@
    num_images = int(line)
    assert num_images == len(cameras)

    logging.info(f'Reading {num_images} images...')
    logger.info(f'Reading {num_images} images...')
    image_idx_to_db_image_id = []
    image_data = []
    i = 0
@@ -85,10 +85,10 @@
    num_points = int(line)

    if skip_points:
        logging.info(f'Skipping {num_points} points.')
        logger.info(f'Skipping {num_points} points.')
        num_points = 0
    else:
        logging.info(f'Reading {num_points} points...')
        logger.info(f'Reading {num_points} points...')
    points3D = {}
    image_idx_to_keypoints = defaultdict(list)
    i = 0
@@ -123,7 +123,7 @@
        pbar.update(1)
    pbar.close()

    logging.info('Parsing image data...')
    logger.info('Parsing image data...')
    images = {}
    for i, data in enumerate(image_data):
        # Skip the focal length. Skip the distortion and terminal 0.
@@ -169,14 +169,14 @@ def main(nvm, intrinsics, database, output, skip_points=False):

    image_ids, camera_ids = recover_database_images_and_ids(database)

    logging.info('Reading the NVM model...')
    logger.info('Reading the NVM model...')
    model = read_nvm_model(
        nvm, intrinsics, image_ids, camera_ids, skip_points=skip_points)

    logging.info('Writing the COLMAP model...')
    logger.info('Writing the COLMAP model...')
    output.mkdir(exist_ok=True, parents=True)
    write_model(*model, path=str(output), ext='.bin')
    logging.info('Done.')
    logger.info('Done.')


if __name__ == '__main__':
33 changes: 20 additions & 13 deletions hloc/extract_features.py
@@ -1,8 +1,8 @@
import argparse
import torch
from pathlib import Path
from typing import Dict, List, Union, Optional
import h5py
import logging
from types import SimpleNamespace
import cv2
import numpy as np
@@ -11,7 +11,7 @@
import collections.abc as collections
import PIL.Image

from . import extractors
from . import extractors, logger
from .utils.base_model import dynamic_load
from .utils.tools import map_tensor
from .utils.parsers import parse_image_lists
@@ -67,7 +67,7 @@
    },
    'r2d2': {
        'output': 'feats-r2d2-n5000-r1024',
        'model':{
        'model': {
            'name': 'r2d2',
            'max_keypoints': 5000,
        },
@@ -142,7 +142,7 @@ class ImageDataset(torch.utils.data.Dataset):
        'grayscale': False,
        'resize_max': None,
        'resize_force': False,
        'interpolation': 'cv2_linear',  # switch to pil_linear for accuracy
        'interpolation': 'cv2_area',  # pil_linear is more accurate but slower
    }

    def __init__(self, root, conf, paths=None):
@@ -157,7 +157,7 @@ def __init__(self, root, conf, paths=None):
                raise ValueError(f'Could not find any image in root: {root}.')
            paths = sorted(list(set(paths)))
            self.names = [i.relative_to(root).as_posix() for i in paths]
            logging.info(f'Found {len(self.names)} images in root {root}.')
            logger.info(f'Found {len(self.names)} images in root {root}.')
        else:
            if isinstance(paths, (Path, str)):
                self.names = parse_image_lists(paths)
@@ -202,10 +202,15 @@ def __len__(self):


@torch.no_grad()
def main(conf, image_dir, export_dir=None, as_half=False,
         image_list=None, feature_path=None):
    logging.info('Extracting local features with configuration:'
                 f'\n{pprint.pformat(conf)}')
def main(conf: Dict,
         image_dir: Path,
         export_dir: Optional[Path] = None,
         as_half: bool = True,
         image_list: Optional[Union[Path, List[str]]] = None,
         feature_path: Optional[Path] = None,
         overwrite: bool = False) -> Path:
    logger.info('Extracting local features with configuration:'
                f'\n{pprint.pformat(conf)}')

    loader = ImageDataset(image_dir, conf['preprocessing'], image_list)
    loader = torch.utils.data.DataLoader(loader, num_workers=1)
@@ -214,9 +219,9 @@ def main(conf, image_dir, export_dir=None, as_half=False,
        feature_path = Path(export_dir, conf['output']+'.h5')
    feature_path.parent.mkdir(exist_ok=True, parents=True)
    skip_names = set(list_h5_names(feature_path)
                     if feature_path.exists() else ())
                     if feature_path.exists() and not overwrite else ())
    if set(loader.dataset.names).issubset(set(skip_names)):
        logging.info('Skipping the extraction.')
        logger.info('Skipping the extraction.')
        return feature_path

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -245,20 +250,22 @@

        with h5py.File(str(feature_path), 'a') as fd:
            try:
                if name in fd:
                    del fd[name]
                grp = fd.create_group(name)
                for k, v in pred.items():
                    grp.create_dataset(k, data=v)
            except OSError as error:
                if 'No space left on device' in error.args[0]:
                    logging.error(
                    logger.error(
                        'Out of disk space: storing features on disk can take '
                        'significant space, did you enable the as_half flag?')
                    del grp, fd[name]
                raise error

        del pred

    logging.info('Finished exporting features.')
    logger.info('Finished exporting features.')
    return feature_path


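The updated `main` signature above can be called from Python as follows (a hedged sketch: the paths are assumptions, while `as_half` and `overwrite` come from this diff):
```python
from pathlib import Path
from hloc import extract_features

# 'superpoint_aachen' is one of the presets defined in this file.
conf = extract_features.confs['superpoint_aachen']
feature_path = extract_features.main(
    conf,
    image_dir=Path('datasets/sacre_coeur/mapping'),  # hypothetical image folder
    export_dir=Path('outputs/demo'),
    as_half=True,      # new default: store descriptors as fp16 to save disk space
    overwrite=False,   # new flag: when True, re-extract features that already exist
)
```
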
3 changes: 2 additions & 1 deletion hloc/extractors/netvlad.py
@@ -10,6 +10,7 @@

from ..utils.base_model import BaseModel

logger = logging.getLogger(__name__)

netvlad_path = Path(__file__).parent / '../../third_party/netvlad'

@@ -65,7 +66,7 @@ def _init(self, conf):
            checkpoint.parent.mkdir(exist_ok=True)
            link = self.dir_models[conf['model_name']]
            cmd = ['wget', link, '-O', str(checkpoint)]
            logging.info(f'Downloading the NetVLAD model with `{cmd}`.')
            logger.info(f'Downloading the NetVLAD model with `{cmd}`.')
            subprocess.run(cmd, check=True)

        # Create the network.
18 changes: 18 additions & 0 deletions hloc/extractors/superpoint.py
@@ -29,16 +29,34 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
superpoint.sample_descriptors = sample_descriptors


# The original keypoint sampling is incorrect. We patch it here but
# we don't fix it upstream to not impact existing evaluations.
def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
    """ Interpolate descriptors at keypoint locations """
    b, c, h, w = descriptors.shape
    keypoints = (keypoints + 0.5) / (keypoints.new_tensor([w, h]) * s)
    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
    descriptors = torch.nn.functional.grid_sample(
        descriptors, keypoints.view(b, 1, -1, 2),
        mode='bilinear', align_corners=False)
    descriptors = torch.nn.functional.normalize(
        descriptors.reshape(b, c, -1), p=2, dim=1)
    return descriptors


class SuperPoint(BaseModel):
    default_conf = {
        'nms_radius': 4,
        'keypoint_threshold': 0.005,
        'max_keypoints': -1,
        'remove_borders': 4,
        'fix_sampling': False,
    }
    required_inputs = ['image']

    def _init(self, conf):
        if conf['fix_sampling']:
            superpoint.sample_descriptors = sample_descriptors_fix_sampling
        self.net = superpoint.SuperPoint(conf)

    def _forward(self, data):
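
Since `fix_sampling` defaults to `False`, the corrected sampling is opt-in. A sketch of enabling it through a feature configuration (the preset values around the flag are assumptions):
```python
# Hypothetical extraction config that enables the patched descriptor sampling:
conf = {
    'output': 'feats-superpoint-fixed',
    'model': {
        'name': 'superpoint',
        'nms_radius': 4,
        'max_keypoints': 4096,
        'fix_sampling': True,  # use sample_descriptors_fix_sampling instead
    },
    'preprocessing': {'grayscale': True, 'resize_max': 1024},
}
```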