diff --git a/examples/datasets/download_dataset.py b/examples/datasets/download_dataset.py index 81b6fe559..29d4031b7 100755 --- a/examples/datasets/download_dataset.py +++ b/examples/datasets/download_dataset.py @@ -1,22 +1,35 @@ -"""Script to download benchmark dataset(s)""" +""" +Script to download benchmark dataset(s) + +By default, this script downloads the 'mipnerf360' dataset. +You can specify a different dataset to download using the --dataset option. +If you want to download all available datasets, you can set the --dataset option to 'all'. +""" import os import subprocess from dataclasses import dataclass from pathlib import Path -from typing import Literal - +from typing import Literal, List import tyro # dataset names -dataset_names = Literal["mipnerf360"] +dataset_names = Literal["mipnerf360", "mipnerf360_extra", "tandt", "deepblending", "all"] # dataset urls -urls = {"mipnerf360": "http://storage.googleapis.com/gresearch/refraw360/360_v2.zip"} +urls = { + "mipnerf360": "http://storage.googleapis.com/gresearch/refraw360/360_v2.zip", + "mipnerf360_extra": "https://storage.googleapis.com/gresearch/refraw360/360_extra_scenes.zip", + "tandt_db": "https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/input/tandt_db.zip" +} # rename maps -dataset_rename_map = {"mipnerf360": "360_v2"} - +dataset_rename_map = { + "mipnerf360": "360_v2", + "mipnerf360_extra": "360_v2", + "tandt": "tandt", + "deepblending": "db" +} @dataclass class DownloadData: @@ -25,52 +38,66 @@ class DownloadData: def main(self): self.save_dir.mkdir(parents=True, exist_ok=True) - self.dataset_download(self.dataset) - - def dataset_download(self, dataset: dataset_names): - (self.save_dir / dataset_rename_map[dataset]).mkdir(parents=True, exist_ok=True) + if self.dataset == "all": + for dataset in urls.keys(): + self.dataset_download(dataset) + else: + self.dataset_download(self.dataset) - file_name = Path(urls[dataset]).name + def dataset_download(self, dataset: str): + if dataset in ["tandt", "deepblending", "tandt_db"]: + url = urls["tandt_db"] + file_name = Path(url).name + extract_dir = self.save_dir + else: + url = urls[dataset] + file_name = Path(url).name + extract_dir = self.save_dir / dataset_rename_map[dataset] # download download_command = [ "wget", "-P", - str(self.save_dir / dataset_rename_map[dataset]), - urls[dataset], + str(extract_dir), + url, ] try: subprocess.run(download_command, check=True) - print("File file downloaded succesfully.") + print(f"File {file_name} downloaded successfully.") except subprocess.CalledProcessError as e: print(f"Error downloading file: {e}") + return # if .zip - if Path(urls[dataset]).suffix == ".zip": + if Path(url).suffix == ".zip": extract_command = [ "unzip", - self.save_dir / dataset_rename_map[dataset] / file_name, + "-o", + extract_dir / file_name, "-d", - self.save_dir / dataset_rename_map[dataset], + extract_dir, ] # if .tar else: extract_command = [ "tar", "-xvzf", - self.save_dir / dataset_rename_map[dataset] / file_name, + extract_dir / file_name, "-C", - self.save_dir / dataset_rename_map[dataset], + extract_dir, ] - # extract try: subprocess.run(extract_command, check=True) - os.remove(self.save_dir / dataset_rename_map[dataset] / file_name) + os.remove(extract_dir / file_name) print("Extraction complete.") except subprocess.CalledProcessError as e: print(f"Extraction failed: {e}") + # For tandt_db, we need to rename the extracted folders + if dataset in ["tandt", "deepblending"]: + os.rename(self.save_dir / "tandt", self.save_dir / dataset_rename_map["tandt"]) + os.rename(self.save_dir / "db", self.save_dir / dataset_rename_map["deepblending"]) if __name__ == "__main__": - tyro.cli(DownloadData).main() + tyro.cli(DownloadData).main() \ No newline at end of file