[UPDATE] enable model loading from hub
add a from_pretrained function to unify the entry point
enable loading models hosted on the Hugging Face Hub
download from the Hugging Face Hub by default
markkua committed Dec 5, 2023
1 parent 2d4aa39 commit be0acfe
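The commit message describes a unified entry point that accepts either a local checkpoint directory or a Hugging Face Hub model name, downloading only when no local directory matches. A minimal sketch of that dispatch idea (the helper name and return values are illustrative, not the repository's actual code):

```python
# Hedged sketch: how a unified from_pretrained-style entry point typically
# decides between a local checkpoint directory and a Hub model name.
import os

def resolve_checkpoint(name_or_path: str) -> str:
    """Return 'local' for an existing directory, else 'hub' (download needed)."""
    if os.path.isdir(name_or_path):
        return "local"
    return "hub"
```

With this convention, `resolve_checkpoint("checkpoint/Marigold_v1_merged")` picks a local load when that directory exists, while a repo id like `Bingxin/Marigold` falls through to a Hub download.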
Showing 6 changed files with 204 additions and 105 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ output/
temp/
wandb/
venv/
cache/

**/.ipynb_checkpoints/
.vscode/
40 changes: 29 additions & 11 deletions README.md
@@ -4,6 +4,7 @@ This repository represents the official implementation of the paper titled "Repu

[![Website](doc/badges/badge-website.svg)](https://marigoldmonodepth.github.io)
[![Paper](doc/badges/badge-pdf.svg)](https://arxiv.org/abs/2312.02145)
[![HF Model](https://img.shields.io/badge/🤗%20Hugging%20Face-Model-yellow)](https://huggingface.co/Bingxin/Marigold)
[![License](doc/badges/badge-license.svg)](LICENSE)
<!-- [![GitHub](https://img.shields.io/github/stars/prs-eth/Marigold?style=default&label=GitHub%20★&logo=github)](https://github.com/prs-eth/Marigold) -->
<!-- [![HF Space](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]() -->
@@ -35,37 +36,34 @@ This code has been tested on:
- Python 3.10.12, PyTorch 2.0.1, CUDA 11.7, GeForce RTX 3090
- Python 3.10.4, PyTorch 2.0.1, CUDA 11.7, GeForce RTX 4090

### 💻 Dependencies

```bash
python -m venv venv/marigold
source venv/marigold/bin/activate
pip install -r requirements.txt
```

## 🚀 Inference on in-the-wild images

### 📷 Sample images

```bash
bash script/download_sample_data.sh
```

### 🎮 Inference

The following script automatically downloads the [checkpoint](https://huggingface.co/Bingxin/Marigold) from the Hugging Face Hub.

```bash
python run.py \
    --input_rgb_dir data/in-the-wild_example \
--output_dir output/in-the-wild_example
```

### ⚙️ Inference settings

- By default, the inference script resizes the input images for inference and resizes the predictions back to the original resolution.

@@ -81,6 +79,26 @@ python run.py \
- `--seed`: Random seed; can be set to ensure reproducibility. Default: None (the current time is used as the seed).
- `--depth_cmap`: Colormap used to colorize the depth prediction. Default: Spectral.

- The model cache directory can be set with the environment variable `HF_HOME`, for example:

```bash
export HF_HOME=$(pwd)/checkpoint
```
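The same cache redirection can be sketched in Python. The `hub` subdirectory is the usual `huggingface_hub` cache layout; treat the exact path structure as an assumption rather than a guarantee:

```python
# Point the Hugging Face cache at a local "checkpoint" directory, then derive
# the expected Hub cache path (the "hub" subfolder is the conventional layout).
import os

os.environ["HF_HOME"] = os.path.join(os.getcwd(), "checkpoint")
cache_dir = os.path.join(os.environ["HF_HOME"], "hub")
```

Setting `HF_HOME` before the first `from_pretrained` call ensures the downloaded weights land under the project directory instead of the default `~/.cache/huggingface`.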

### ⬇ Using local checkpoint

```bash
# Download checkpoint
bash script/download_weights.sh
```

```bash
python run.py \
--checkpoint checkpoint/Marigold_v1_merged \
    --input_rgb_dir data/in-the-wild_example \
--output_dir output/in-the-wild_example
```

## 🎓 Citation

```bibtex
@@ -95,4 +113,4 @@

This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.

[<img src="doc/badges/badge-license.svg" height="20"/>](http://creativecommons.org/licenses/by-nc-sa/4.0/)
14 changes: 7 additions & 7 deletions run.py
@@ -1,7 +1,7 @@
# Script for inference on (in-the-wild) images

# Author: Bingxin Ke
# Last modified: 2023-12-05


import argparse
@@ -26,7 +26,7 @@
if "__main__" == __name__:
# -------------------- Arguments --------------------
parser = argparse.ArgumentParser(description="Run single-image depth estimation using Marigold.")
parser.add_argument("--checkpoint", type=str, default="Bingxin/Marigold", help="Checkpoint path or hub name.")

parser.add_argument("--input_rgb_dir", type=str, required=True, help="Path to the input image folder.")

@@ -100,18 +100,18 @@
device = torch.device("cuda" if cuda_avail else "cpu")
print(f"device = {device}")

# -------------------- Model --------------------
model = MarigoldPipeline.from_pretrained(checkpoint_path)

model = model.to(device)

# -------------------- Data --------------------
rgb_filename_list = glob(os.path.join(input_rgb_dir, "*"))
rgb_filename_list = [
f for f in rgb_filename_list if os.path.splitext(f)[1].lower() in EXTENSION_LIST
]
print(f"Found {len(rgb_filename_list)} images")

# -------------------- Inference and saving --------------------
with torch.no_grad():
os.makedirs(output_dir, exist_ok=True)
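The filename filtering in the hunk above depends on `EXTENSION_LIST`, which run.py imports elsewhere; a minimal, self-contained sketch of the same logic (the extension values shown here are assumed, not taken from the repository):

```python
# Filter a directory listing down to supported image files, matching
# extensions case-insensitively as the script does.
import os

EXTENSION_LIST = [".jpg", ".jpeg", ".png"]  # assumed example values
files = ["a.JPG", "b.png", "notes.txt", "c.jpeg"]
rgb = [f for f in files if os.path.splitext(f)[1].lower() in EXTENSION_LIST]
# rgb == ["a.JPG", "b.png", "c.jpeg"]
```

Lowercasing the extension before the membership test is what lets uppercase suffixes like `.JPG` pass the filter.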
