Skip to content

Commit

Permalink
[UPDATE] minor updates
Browse files Browse the repository at this point in the history
README: 1.update badges 2.add Google Colab
requirements.txt: optimize and fix versions
run.py: exit when no input image
  • Loading branch information
markkua committed Dec 5, 2023
1 parent be0acfe commit 049e77e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 73 deletions.
16 changes: 11 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
This repository represents the official implementation of the paper titled "Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation".

[![Website](doc/badges/badge-website.svg)](https://marigoldmonodepth.github.io)
[![Paper](doc/badges/badge-pdf.svg)](https://arxiv.org/abs/2312.02145)
[![Paper](https://img.shields.io/badge/arXiv-PDF-b31b1b)](https://arxiv.org/abs/2312.02145)
[![Open In Colab](doc/badges/badge-colab.svg)](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing)
[![HF Model](https://img.shields.io/badge/🤗%20Hugging%20Face-Model-yellow)](https://huggingface.co/Bingxin/Marigold)
[![License](doc/badges/badge-license.svg)](LICENSE)
[![License](https://img.shields.io/badge/License-CC_BY--NC--SA_4.0-929292)](LICENSE)
<!-- [![Website](https://img.shields.io/badge/Project-Website-1081c2)](https://arxiv.org/abs/2312.02145) -->
<!-- [![GitHub](https://img.shields.io/github/stars/prs-eth/Marigold?style=default&label=GitHub%20★&logo=github)](https://github.com/prs-eth/Marigold) -->
<!-- [![HF Space](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space-blue)]() -->
<!-- [![Open In Colab](doc/badges/badge-colab.svg)]() -->
<!-- [![Docker](doc/badges/badge-docker.svg)]() -->
<!-- ### [Repurposing Diffusion-Based Image Generators for Monocular Depth Estimation]() -->

[Bingxin Ke](http://www.kebingxin.com/),
[Anton Obukhov](https://www.obukhov.ai/),
Expand All @@ -36,6 +36,12 @@ This code has been tested on:
- Python 3.10.12, PyTorch 2.0.1, CUDA 11.7, GeForce RTX 3090
- python 3.10.4, Pytorch 2.0.1, CUDA 11.7, GeForce RTX 4090

### 📦 Repository

```bash
git clone https://github.com/prs-eth/Marigold.git
cd Marigold
```

### 💻 Dependencies

Expand Down Expand Up @@ -113,4 +119,4 @@ python run.py \

This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.

[<img src="doc/badges/badge-license.svg" height="20"/>](http://creativecommons.org/licenses/by-nc-sa/4.0/)
[![License](https://img.shields.io/badge/License-CC_BY--NC--SA_4.0-929292)](LICENSE)
27 changes: 0 additions & 27 deletions doc/badges/badge-license.svg

This file was deleted.

27 changes: 0 additions & 27 deletions doc/badges/badge-pdf.svg

This file was deleted.

13 changes: 5 additions & 8 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,13 @@ matplotlib
numpy==1.26.1
omegaconf
opencv-python
pandas
scipy==1.11.3
tabulate
tensorboard
torch==2.0.1
torchaudio
torchvision
torchshow
torchaudio==2.0.2
torchvision==0.15.2
tqdm
transformers
triton
transformers==4.32.1
triton==2.0.0
wandb==0.14.0
xformers
xformers==0.0.21
17 changes: 11 additions & 6 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,22 @@
device = torch.device("cuda" if cuda_avail else "cpu")
print(f"device = {device}")

# -------------------- Model --------------------
model = MarigoldPipeline.from_pretrained(checkpoint_path)

model = model.to(device)

# -------------------- Data --------------------
rgb_filename_list = glob(os.path.join(input_rgb_dir, "*"))
rgb_filename_list = [
f for f in rgb_filename_list if os.path.splitext(f)[1].lower() in EXTENSION_LIST
]
print(f"Found {len(rgb_filename_list)} images")
n_images = len(rgb_filename_list)
if n_images > 0:
print(f"Found {n_images} images")
else:
print(f"No image found in '{input_rgb_dir}'")
exit(1)

# -------------------- Model --------------------
model = MarigoldPipeline.from_pretrained(checkpoint_path)

model = model.to(device)

# -------------------- Inference and saving --------------------
with torch.no_grad():
Expand Down

0 comments on commit 049e77e

Please sign in to comment.