Skip to content

Commit

Permalink
Create data folder before downloading MegaDepth1500 (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarlinpe authored Oct 9, 2023
1 parent 12640af commit f7b587e
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions gluefactory/eval/megadepth1500.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import zipfile
from collections import defaultdict
from collections.abc import Iterable
Expand All @@ -19,6 +20,8 @@
from .io import get_eval_parser, load_model, parse_eval_args
from .utils import eval_matches_epipolar, eval_poses, eval_relative_pose_robust

logger = logging.getLogger(__name__)


class MegaDepth1500Pipeline(EvalPipeline):
default_conf = {
Expand Down Expand Up @@ -56,11 +59,13 @@ class MegaDepth1500Pipeline(EvalPipeline):

def _init(self, conf):
if not (DATA_PATH / "megadepth1500").exists():
logger.info("Downloading the MegaDepth-1500 dataset.")
url = "https://cvg-data.inf.ethz.ch/megadepth/megadepth1500.zip"
zip_path = DATA_PATH / url.rsplit("/", 1)[-1]
zip_path.parent.mkdir(exist_ok=True, parents=True)
torch.hub.download_url_to_file(url, zip_path)
with zipfile.ZipFile(zip_path) as zip:
zip.extractall(DATA_PATH)
with zipfile.ZipFile(zip_path) as fid:
fid.extractall(DATA_PATH)
zip_path.unlink()

@classmethod
Expand Down Expand Up @@ -147,6 +152,8 @@ def run_eval(self, loader, pred_file):


if __name__ == "__main__":
from .. import logger # overwrite the logger

dataset_name = Path(__file__).stem
parser = get_eval_parser()
args = parser.parse_intermixed_args()
Expand Down

0 comments on commit f7b587e

Please sign in to comment.