diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3a72f6d --- /dev/null +++ b/.gitignore @@ -0,0 +1,134 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# I/O +input/ +output/ +weights/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..fbd5847 --- /dev/null +++ b/README.md @@ -0,0 +1,130 @@ +# Monocular Visual-Inertial Depth Estimation + +This repository contains code and models for our paper: + +> Monocular Visual-Inertial Depth Estimation +> Diana Wofk, René Ranftl, Matthias Müller, Vladlen Koltun + +## Introduction + +![Methodology Diagram](figures/methodology_diagram.png) + +We present a visual-inertial depth estimation pipeline that integrates monocular depth estimation and visual-inertial odometry to produce dense depth estimates with metric scale. Our approach consists of three stages: (1) input processing, where RGB and IMU data feed into monocular depth estimation alongside visual-inertial odometry, (2) global scale and shift alignment, where monocular depth estimates are fitted to sparse depth from VIO in a least-squares manner, and (3) learning-based dense scale alignment, where globally-aligned depth is locally realigned using a dense scale map regressed by the ScaleMapLearner (SML). The images at the bottom in the diagram above illustrate a VOID sample being processed through our pipeline; from left to right: the input RGB, ground truth depth, sparse depth from VIO, globally-aligned depth, scale map scaffolding, dense scale map regressed by SML, final depth output. 
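+
+For reference, the global alignment in stage (2) reduces to a closed-form least-squares fit in inverse-depth space. The sketch below is illustrative only (the function and variable names are not part of this repository); the full implementation, including clamping of the aligned prediction, lives in `modules/estimator.py`:
+
+```python
+import numpy as np
+
+def global_scale_and_shift(pred_inv, sparse_inv, valid):
+    # fit scale s and shift t so that s * pred_inv + t ~= sparse_inv
+    # over the valid sparse points (2x2 normal equations)
+    p, d = pred_inv[valid], sparse_inv[valid]
+    a00, a01, a11 = np.sum(p * p), np.sum(p), np.count_nonzero(valid)
+    b0, b1 = np.sum(p * d), np.sum(d)
+    det = a00 * a11 - a01 * a01  # must be > 0 for a unique solution
+    s = (a11 * b0 - a01 * b1) / det
+    t = (-a01 * b0 + a00 * b1) / det
+    return s * pred_inv + t  # globally-aligned inverse depth
+```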
+ +![Teaser Figure](figures/teaser_figure.png) + +## Setup + +1) Setup dependencies: + + ```shell + conda env create -f environment.yaml + conda activate vi-depth + ``` + +2) Pick one or more ScaleMapLearner (SML) models and download the corresponding weights to the `weights` folder. + + | Depth Predictor | SML on VOID 150 | SML on VOID 500 | SML on VOID 1500 | + | :--- | :----: | :----: | :----: | + | DPT-BEiT-Large | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_beit_large_512.nsamples.150.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_beit_large_512.nsamples.500.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_beit_large_512.nsamples.1500.ckpt) | + | DPT-SwinV2-Large | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_swin2_large_384.nsamples.150.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_swin2_large_384.nsamples.500.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_swin2_large_384.nsamples.1500.ckpt) | + | DPT-Large | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_large.nsamples.150.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_large.nsamples.500.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_large.nsamples.1500.ckpt) | + | DPT-Hybrid | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_hybrid.nsamples.150.ckpt)* | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_hybrid.nsamples.500.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_hybrid.nsamples.1500.ckpt) | + | DPT-SwinV2-Tiny | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_swin2_tiny_256.nsamples.150.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_swin2_tiny_256.nsamples.500.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_swin2_tiny_256.nsamples.1500.ckpt) | + | DPT-LeViT | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_levit_224.nsamples.150.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_levit_224.nsamples.500.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_levit_224.nsamples.1500.ckpt) | + | MiDaS-small | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.midas_small.nsamples.150.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.midas_small.nsamples.500.ckpt) | [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.midas_small.nsamples.1500.ckpt) | + + *Also available with pretraining on TartanAir: [model](https://github.com/isl-org/VI-Depth/releases/download/v1/sml_model.dpredictor.dpt_hybrid.nsamples.150.pretrained.ckpt) + +## Inference + +1) Place inputs into the `input` folder. An input image and corresponding sparse metric depth map are expected: + + ```bash + input + ├── image # RGB image + │ ├── .png + │ └── ... 
+ └── sparse_depth # sparse metric depth map + ├── .png # as 16b PNG + └── ... + ``` + + The `load_sparse_depth` function in `run.py` may need to be modified depending on the format in which sparse depth is stored. By default, the depth storage method [used in the VOID dataset](https://github.com/alexklwong/void-dataset/blob/master/src/data_utils.py) is assumed. + +2) Run the `run.py` script as follows: + + ```bash + DEPTH_PREDICTOR="dpt_beit_large_512" + NSAMPLES=150 + SML_MODEL_PATH="weights/sml_model.dpredictor.${DEPTH_PREDICTOR}.nsamples.${NSAMPLES}.ckpt" + + python run.py -dp $DEPTH_PREDICTOR -ns $NSAMPLES -sm $SML_MODEL_PATH --save-output + ``` + +3) The `--save-output` flag enables saving outputs to the `output` folder. By default, the following outputs will be saved per sample: + + ```bash + output + ├── ga_depth # metric depth map after global alignment + │ ├── .pfm # as PFM + │ ├── .png # as 16b PNG + │ └── ... + └── sml_depth # metric depth map output by SML + ├── .pfm # as PFM + ├── .png # as 16b PNG + └── ... + ``` + +## Evaluation + +Models provided in this repo were trained on the VOID dataset. +1) Download the VOID dataset following [the instructions in the VOID dataset repo](https://github.com/alexklwong/void-dataset#downloading-void). +2) To evaluate on VOID test sets, run the `evaluate.py` script as follows: + + ```bash + DATASET_PATH="/path/to/void_release/" + + DEPTH_PREDICTOR="dpt_beit_large_512" + NSAMPLES=150 + SML_MODEL_PATH="weights/sml_model.dpredictor.${DEPTH_PREDICTOR}.nsamples.${NSAMPLES}.ckpt" + + python evaluate.py -ds $DATASET_PATH -dp $DEPTH_PREDICTOR -ns $NSAMPLES -sm $SML_MODEL_PATH + ``` + + Results for the example shown above: + + ``` + Averaging metrics for globally-aligned depth over 800 samples + Averaging metrics for SML-aligned depth over 800 samples + +---------+----------+----------+ + | metric | GA Only | GA+SML | + +---------+----------+----------+ + | RMSE | 191.36 | 142.85 | + | MAE | 115.84 | 76.95 | + | AbsRel | 0.069 | 0.046 | + | iRMSE | 72.70 | 57.13 | + | iMAE | 49.32 | 34.25 | + | iAbsRel | 0.071 | 0.048 | + +---------+----------+----------+ + ``` + + To evaluate on VOID test sets at different densities (void_150, void_500, void_1500), change the `NSAMPLES` argument above accordingly. + +## Citation + +If you reference our work, please consider citing the following: + +```bib +@inproceedings{wofk2023videpth, + author = {{Wofk, Diana and Ranftl, Ren\'{e} and M{\"u}ller, Matthias and Koltun, Vladlen}}, + title = {{Monocular Visual-Inertial Depth Estimation}}, + booktitle = {{IEEE International Conference on Robotics and Automation (ICRA)}}, + year = {{2023}} +} +``` + +## Acknowledgements + +Our work builds on and uses code from [MiDaS](https://github.com/isl-org/MiDaS), [timm](https://github.com/rwightman/pytorch-image-models), and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/). We'd like to thank the authors for making these libraries and frameworks available. 
+ diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..e74463d --- /dev/null +++ b/environment.yaml @@ -0,0 +1,18 @@ +name: vi-depth +channels: + - pytorch + - defaults +dependencies: + - nvidia::cudatoolkit=11.7 + - python=3.10.8 + - pytorch::pytorch=1.13.0 + - torchvision=0.14.0 + - pip=22.3.1 + - numpy=1.23.4 + - pip: + - opencv-python==4.6.0.66 + - scipy==1.10.1 + - timm==0.6.12 + - pytorch-lightning==1.9.0 + - imageio==2.25.0 + - prettytable==3.6.0 \ No newline at end of file diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..a717566 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,138 @@ +import os +import argparse + +import torch +import imageio +import numpy as np + +from tqdm import tqdm +from PIL import Image + +import modules.midas.utils as utils + +import pipeline +import metrics + +def evaluate(dataset_path, depth_predictor, nsamples, sml_model_path): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print("device: %s" % device) + + # ranges for VOID + min_depth, max_depth = 0.2, 5.0 + min_pred, max_pred = 0.1, 8.0 + + # instantiate method + method = pipeline.VIDepth( + depth_predictor, nsamples, sml_model_path, + min_pred, max_pred, min_depth, max_depth, device + ) + + # get inputs + with open(f"{dataset_path}/void_{nsamples}/test_image.txt") as f: + test_image_list = [line.rstrip() for line in f] + + # initialize error aggregators + avg_error_w_int_depth = metrics.ErrorMetricsAverager() + avg_error_w_pred = metrics.ErrorMetricsAverager() + + # iterate through inputs list + for i in tqdm(range(len(test_image_list))): + + # image + input_image_fp = os.path.join(dataset_path, test_image_list[i]) + input_image = utils.read_image(input_image_fp) + + # sparse depth + input_sparse_depth_fp = input_image_fp.replace("image", "sparse_depth") + input_sparse_depth = np.array(Image.open(input_sparse_depth_fp), dtype=np.float32) / 256.0 + input_sparse_depth[input_sparse_depth <= 0] = 0.0 + + # sparse depth validity map + validity_map_fp = input_image_fp.replace("image", "validity_map") + validity_map = np.array(Image.open(validity_map_fp), dtype=np.float32) + assert(np.all(np.unique(validity_map) == [0, 256])) + validity_map[validity_map > 0] = 1 + + # target (ground truth) depth + target_depth_fp = input_image_fp.replace("image", "ground_truth") + target_depth = np.array(Image.open(target_depth_fp), dtype=np.float32) / 256.0 + target_depth[target_depth <= 0] = 0.0 + + # target depth valid/mask + mask = (target_depth < max_depth) + if min_depth is not None: + mask *= (target_depth > min_depth) + target_depth[~mask] = np.inf # set invalid depth + target_depth = 1.0 / target_depth + + # run pipeline + output = method.run(input_image, input_sparse_depth, validity_map, device) + + # compute error metrics using intermediate (globally aligned) depth + error_w_int_depth = metrics.ErrorMetrics() + error_w_int_depth.compute( + estimate = output["ga_depth"], + target = target_depth, + valid = mask.astype(np.bool), + ) + + # compute error metrics using SML output depth + error_w_pred = metrics.ErrorMetrics() + error_w_pred.compute( + estimate = output["sml_depth"], + target = target_depth, + valid = mask.astype(np.bool), + ) + + # accumulate error metrics + avg_error_w_int_depth.accumulate(error_w_int_depth) + avg_error_w_pred.accumulate(error_w_pred) + + + # compute average error metrics + print("Averaging metrics for globally-aligned depth over {} samples".format( + avg_error_w_int_depth.total_count + )) + 
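+    # average() divides each accumulated metric sum by total_count in place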
avg_error_w_int_depth.average() + + print("Averaging metrics for SML-aligned depth over {} samples".format( + avg_error_w_pred.total_count + )) + avg_error_w_pred.average() + + from prettytable import PrettyTable + summary_tb = PrettyTable() + summary_tb.field_names = ["metric", "GA Only", "GA+SML"] + + summary_tb.add_row(["RMSE", f"{avg_error_w_int_depth.rmse_avg:7.2f}", f"{avg_error_w_pred.rmse_avg:7.2f}"]) + summary_tb.add_row(["MAE", f"{avg_error_w_int_depth.mae_avg:7.2f}", f"{avg_error_w_pred.mae_avg:7.2f}"]) + summary_tb.add_row(["AbsRel", f"{avg_error_w_int_depth.absrel_avg:8.3f}", f"{avg_error_w_pred.absrel_avg:8.3f}"]) + summary_tb.add_row(["iRMSE", f"{avg_error_w_int_depth.inv_rmse_avg:7.2f}", f"{avg_error_w_pred.inv_rmse_avg:7.2f}"]) + summary_tb.add_row(["iMAE", f"{avg_error_w_int_depth.inv_mae_avg:7.2f}", f"{avg_error_w_pred.inv_mae_avg:7.2f}"]) + summary_tb.add_row(["iAbsRel", f"{avg_error_w_int_depth.inv_absrel_avg:8.3f}", f"{avg_error_w_pred.inv_absrel_avg:8.3f}"]) + + print(summary_tb) + + +if __name__=="__main__": + + parser = argparse.ArgumentParser() + + parser.add_argument('-ds', '--dataset-path', type=str, default='/path/to/void_release/', + help='Path to VOID release dataset.') + parser.add_argument('-dp', '--depth-predictor', type=str, default='midas_small', + help='Name of depth predictor to use in pipeline.') + parser.add_argument('-ns', '--nsamples', type=int, default=150, + help='Number of sparse metric depth samples available.') + parser.add_argument('-sm', '--sml-model-path', type=str, default='', + help='Path to trained SML model weights.') + + args = parser.parse_args() + print(args) + + evaluate( + args.dataset_path, + args.depth_predictor, + args.nsamples, + args.sml_model_path, + ) \ No newline at end of file diff --git a/figures/methodology_diagram.png b/figures/methodology_diagram.png new file mode 100644 index 0000000..828801e Binary files /dev/null and b/figures/methodology_diagram.png differ diff --git a/figures/teaser_figure.png b/figures/teaser_figure.png new file mode 100644 index 0000000..61739b6 Binary files /dev/null and b/figures/teaser_figure.png differ diff --git a/input/image/.placeholder b/input/image/.placeholder new file mode 100644 index 0000000..e69de29 diff --git a/input/sparse_depth/.placeholder b/input/sparse_depth/.placeholder new file mode 100644 index 0000000..e69de29 diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000..5d3655b --- /dev/null +++ b/metrics.py @@ -0,0 +1,76 @@ +import numpy as np +import torch + +def rmse(estimate, target): + return np.sqrt(np.mean((estimate - target) ** 2)) + +def mae(estimate, target): + return np.mean(np.abs(estimate - target)) + +def absrel(estimate, target): + return np.mean(np.abs(estimate - target) / target) + +def inv_rmse(estimate, target): + return np.sqrt(np.mean((1.0/estimate - 1.0/target) ** 2)) + +def inv_mae(estimate, target): + return np.mean(np.abs(1.0/estimate - 1.0/target)) + +def inv_absrel(estimate, target): + return np.mean((np.abs(1.0/estimate - 1.0/target)) / (1.0/target)) + +class ErrorMetrics(object): + def __init__(self): + # initialize by setting to worst values + self.rmse, self.mae, self.absrel = np.inf, np.inf, np.inf + self.inv_rmse, self.inv_mae, self.inv_absrel = np.inf, np.inf, np.inf + + def compute(self, estimate, target, valid): + # apply valid masks + estimate = estimate[valid] + target = target[valid] + + # estimate and target will be in inverse space, convert to regular + estimate = 1.0/estimate + target = 1.0/target + + # depth error, 
estimate in meters, convert units to mm + self.rmse = rmse(1000.0*estimate, 1000.0*target) + self.mae = mae(1000.0*estimate, 1000.0*target) + self.absrel = absrel(1000.0*estimate, 1000.0*target) + + # inverse depth error, estimate in meters, convert units to 1/km + self.inv_rmse = inv_rmse(0.001*estimate, 0.001*target) + self.inv_mae = inv_mae(0.001*estimate, 0.001*target) + self.inv_absrel = inv_absrel(0.001*estimate, 0.001*target) + +class ErrorMetricsAverager(object): + def __init__(self): + # initialize avg accumulators to zero + self.rmse_avg, self.mae_avg, self.absrel_avg = 0, 0, 0 + self.inv_rmse_avg, self.inv_mae_avg, self.inv_absrel_avg = 0, 0, 0 + self.total_count = 0 + + def accumulate(self, error_metrics): + # adds to accumulators from ErrorMetrics object + assert isinstance(error_metrics, ErrorMetrics) + + self.rmse_avg += error_metrics.rmse + self.mae_avg += error_metrics.mae + self.absrel_avg += error_metrics.absrel + + self.inv_rmse_avg += error_metrics.inv_rmse + self.inv_mae_avg += error_metrics.inv_mae + self.inv_absrel_avg += error_metrics.inv_absrel + + self.total_count += 1 + + def average(self): + # print(f"Averaging depth metrics over {self.total_count} samples") + self.rmse_avg = self.rmse_avg / self.total_count + self.mae_avg = self.mae_avg / self.total_count + self.absrel_avg = self.absrel_avg / self.total_count + # print(f"Averaging inv depth metrics over {self.total_count} samples") + self.inv_rmse_avg = self.inv_rmse_avg / self.total_count + self.inv_mae_avg = self.inv_mae_avg / self.total_count + self.inv_absrel_avg = self.inv_absrel_avg / self.total_count \ No newline at end of file diff --git a/modules/estimator.py b/modules/estimator.py new file mode 100644 index 0000000..b914c71 --- /dev/null +++ b/modules/estimator.py @@ -0,0 +1,60 @@ +import numpy as np + +def compute_scale_and_shift_ls(prediction, target, mask): + # tuple specifying with axes to sum + sum_axes = (0, 1) + + # system matrix: A = [[a_00, a_01], [a_10, a_11]] + a_00 = np.sum(mask * prediction * prediction, sum_axes) + a_01 = np.sum(mask * prediction, sum_axes) + a_11 = np.sum(mask, sum_axes) + + # right hand side: b = [b_0, b_1] + b_0 = np.sum(mask * prediction * target, sum_axes) + b_1 = np.sum(mask * target, sum_axes) + + # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b + x_0 = np.zeros_like(b_0) + x_1 = np.zeros_like(b_1) + + det = a_00 * a_11 - a_01 * a_01 + # A needs to be a positive definite matrix. 
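+    # the diagonal entries a_00 and a_11 are non-negative sums, so det > 0 is
+    # sufficient for positive definiteness and a unique scale/shift solution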
+ valid = det > 0 + + x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] + x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] + + return x_0, x_1 + +class LeastSquaresEstimator(object): + def __init__(self, estimate, target, valid): + self.estimate = estimate + self.target = target + self.valid = valid + + # to be computed + self.scale = 1.0 + self.shift = 0.0 + self.output = None + + def compute_scale_and_shift(self): + self.scale, self.shift = compute_scale_and_shift_ls(self.estimate, self.target, self.valid) + + def apply_scale_and_shift(self): + self.output = self.estimate * self.scale + self.shift + + def clamp_min_max(self, clamp_min=None, clamp_max=None): + if clamp_min is not None: + if clamp_min > 0: + clamp_min_inv = 1.0/clamp_min + self.output[self.output > clamp_min_inv] = clamp_min_inv + assert np.max(self.output) <= clamp_min_inv + else: # divide by zero, so skip + pass + if clamp_max is not None: + clamp_max_inv = 1.0/clamp_max + self.output[self.output < clamp_max_inv] = clamp_max_inv + # print(np.min(self.output), clamp_max_inv) + assert np.min(self.output) >= clamp_max_inv + # check for nonzero range + # assert np.min(self.output) != np.max(self.output) \ No newline at end of file diff --git a/modules/interpolator.py b/modules/interpolator.py new file mode 100644 index 0000000..08570e2 --- /dev/null +++ b/modules/interpolator.py @@ -0,0 +1,50 @@ +import numpy as np +np.set_printoptions(suppress=True) + +from scipy.interpolate import griddata + + +def interpolate_knots(map_size, knot_coords, knot_values, interpolate, fill_corners): + grid_x, grid_y = np.mgrid[0:map_size[0], 0:map_size[1]] + + interpolated_map = griddata( + points=knot_coords.T, + values=knot_values, + xi=(grid_y, grid_x), + method=interpolate, + fill_value=1.0) + + return interpolated_map + + +class Interpolator2D(object): + def __init__(self, pred_inv, sparse_depth_inv, valid): + self.pred_inv = pred_inv + self.sparse_depth_inv = sparse_depth_inv + self.valid = valid + + self.map_size = np.shape(pred_inv) + self.num_knots = np.sum(valid) + nonzero_y_loc = np.nonzero(valid)[0] + nonzero_x_loc = np.nonzero(valid)[1] + self.knot_coords = np.stack((nonzero_x_loc, nonzero_y_loc)) + self.knot_scales = sparse_depth_inv[valid] / pred_inv[valid] + self.knot_shifts = sparse_depth_inv[valid] - pred_inv[valid] + + self.knot_list = [] + for i in range(self.num_knots): + self.knot_list.append((int(self.knot_coords[0,i]), int(self.knot_coords[1,i]))) + + # to be computed + self.interpolated_map = None + self.confidence_map = None + self.output = None + + def generate_interpolated_scale_map(self, interpolate_method, fill_corners=False): + self.interpolated_scale_map = interpolate_knots( + map_size=self.map_size, + knot_coords=self.knot_coords, + knot_values=self.knot_scales, + interpolate=interpolate_method, + fill_corners=fill_corners + ).astype(np.float32) \ No newline at end of file diff --git a/modules/midas/base_model.py b/modules/midas/base_model.py new file mode 100644 index 0000000..48dcd72 --- /dev/null +++ b/modules/midas/base_model.py @@ -0,0 +1,26 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. 
+ + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + if "state_dict" in parameters: + state_dict = parameters["state_dict"] + new_state_dict = {} + for key in state_dict.keys(): + if key[0:6] == "model.": + new_state_dict[key[6:]] = state_dict[key] + + self.load_state_dict(new_state_dict) + + else: + self.load_state_dict(parameters) diff --git a/modules/midas/blocks.py b/modules/midas/blocks.py new file mode 100644 index 0000000..fef50db --- /dev/null +++ b/modules/midas/blocks.py @@ -0,0 +1,196 @@ +import torch +import torch.nn as nn + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True): + if backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. 
+ + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + + +class OutputConv(nn.Module): + """Output conv block. + """ + + def __init__(self, features, groups, activation, non_negative): + + super(OutputConv, self).__init__() + + self.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=groups), + nn.Upsample(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + def forward(self, x): + return self.output_conv(x) \ No newline at end of file diff --git a/modules/midas/midas_net_custom.py b/modules/midas/midas_net_custom.py new file mode 100644 index 0000000..53eb14f --- /dev/null +++ b/modules/midas/midas_net_custom.py @@ -0,0 +1,135 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from torch.nn import functional as F + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock_custom, _make_encoder, OutputConv + +def weights_init(m): + import math + # initialize from normal (Gaussian) distribution + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class MidasNet_small_videpth(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=False, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}, in_channels=2, regress='r', min_pred=None, max_pred=None): + """Init. 
+ + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 64. + backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3. + """ + print("Loading weights: ", path) + + super(MidasNet_small_videpth, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + # for model output + self.regress = regress + self.min_pred = min_pred + self.max_pred = max_pred + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.first = nn.Sequential( + nn.Conv2d(in_channels, 3, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(3), + nn.ReLU(inplace=True) + ) + self.first.apply(weights_init) + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + self.scratch.output_conv = OutputConv(features, self.groups, self.scratch.activation, non_negative) + + if path: + self.load(path) + + + def forward(self, x, d): + """Forward pass. 
+ + Args: + x (tensor): input data (image) + d (tensor): unalterated input depth + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + layer_0 = self.first(x) + + layer_1 = self.pretrained.layer1(layer_0) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + scales = F.relu(1.0 + out) + pred = d * scales + + # clamp pred to min and max + if self.min_pred is not None: + min_pred_inv = 1.0/self.min_pred + pred[pred > min_pred_inv] = min_pred_inv + if self.max_pred is not None: + max_pred_inv = 1.0/self.max_pred + pred[pred < max_pred_inv] = max_pred_inv + + # also return scales + return (pred, scales) \ No newline at end of file diff --git a/modules/midas/normalization.py b/modules/midas/normalization.py new file mode 100644 index 0000000..6810e21 --- /dev/null +++ b/modules/midas/normalization.py @@ -0,0 +1,109 @@ +VOID_INTERMEDIATE = { + + "dpt_beit_large_512" : { + "void_150" : { + "mean" : {"int_depth" : 0.730, "int_scales" : 0.380}, + "std" : {"int_depth" : 0.226, "int_scales" : 0.102}, + }, + "void_500" : { + "mean" : {"int_depth" : 0.736, "int_scales" : 0.366}, + "std" : {"int_depth" : 0.232, "int_scales" : 0.099}, + }, + "void_1500" : { + "mean" : {"int_depth" : 0.730, "int_scales" : 0.355}, + "std" : {"int_depth" : 0.232, "int_scales" : 0.096}, + }, + }, + + "dpt_swin2_large_384" : { + "void_150" : { + "mean" : {"int_depth" : 0.730, "int_scales" : 0.402}, + "std" : {"int_depth" : 0.219, "int_scales" : 0.107}, + }, + "void_500" : { + "mean" : {"int_depth" : 0.736, "int_scales" : 0.389}, + "std" : {"int_depth" : 0.224, "int_scales" : 0.106}, + }, + "void_1500" : { + "mean" : {"int_depth" : 0.730, "int_scales" : 0.377}, + "std" : {"int_depth" : 0.226, "int_scales" : 0.103}, + }, + }, + + "dpt_large" : { + "void_150" : { + "mean" : {"int_depth" : 0.729, "int_scales" : 0.403}, + "std" : {"int_depth" : 0.213, "int_scales" : 0.116}, + }, + "void_500" : { + "mean" : {"int_depth" : 0.735, "int_scales" : 0.390}, + "std" : {"int_depth" : 0.219, "int_scales" : 0.116}, + }, + "void_1500" : { + "mean" : {"int_depth" : 0.730, "int_scales" : 0.380}, + "std" : {"int_depth" : 0.221, "int_scales" : 0.116}, + }, + }, + + "dpt_hybrid": { + "void_150" : { + "mean" : {"int_depth" : 0.729, "int_scales" : 0.404}, + "std" : {"int_depth" : 0.210, "int_scales" : 0.117}, + }, + "void_500" : { + "mean" : {"int_depth" : 0.735, "int_scales" : 0.392}, + "std" : {"int_depth" : 0.215, "int_scales" : 0.118}, + }, + "void_1500" : { + "mean" : {"int_depth" : 0.730, "int_scales" : 0.381}, + "std" : {"int_depth" : 0.218, "int_scales" : 0.117}, + }, + }, + + "dpt_swin2_tiny_256" : { + "void_150" : { + "mean" : {"int_depth" : 0.735, "int_scales" : 0.419}, + "std" : {"int_depth" : 0.207, "int_scales" : 0.122}, + }, + "void_500" : { + "mean" : {"int_depth" : 0.741, "int_scales" : 0.406}, + "std" : {"int_depth" : 0.212, "int_scales" : 0.124}, + }, + "void_1500" : { + "mean" : {"int_depth" : 
0.733, "int_scales" : 0.396}, + "std" : {"int_depth" : 0.213, "int_scales" : 0.125}, + }, + }, + + "dpt_levit_224" : { + "void_150" : { + "mean" : {"int_depth" : 0.734, "int_scales" : 0.421}, + "std" : {"int_depth" : 0.198, "int_scales" : 0.129}, + }, + "void_500" : { + "mean" : {"int_depth" : 0.740, "int_scales" : 0.410}, + "std" : {"int_depth" : 0.202, "int_scales" : 0.134}, + }, + "void_1500" : { + "mean" : {"int_depth" : 0.734, "int_scales" : 0.400}, + "std" : {"int_depth" : 0.204, "int_scales" : 0.137}, + }, + }, + + "midas_small" : { + "void_150" : { + "mean" : {"int_depth" : 0.723, "int_scales" : 0.402}, + "std" : {"int_depth" : 0.190, "int_scales" : 0.132}, + }, + "void_500" : { + "mean" : {"int_depth" : 0.731, "int_scales" : 0.393}, + "std" : {"int_depth" : 0.196, "int_scales" : 0.136}, + }, + "void_1500" : { + "mean" : {"int_depth" : 0.728, "int_scales" : 0.385}, + "std" : {"int_depth" : 0.199, "int_scales" : 0.140}, + }, + }, + +} + diff --git a/modules/midas/transforms.py b/modules/midas/transforms.py new file mode 100644 index 0000000..38ba7ab --- /dev/null +++ b/modules/midas/transforms.py @@ -0,0 +1,323 @@ +import numpy as np +import cv2 +import math +import torch +import torchvision.transforms as transforms + +from modules.midas.utils import normalize_unit_range +import modules.midas.normalization as normalization + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". 
+ """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + for item in sample.keys(): + interpolation_method = self.__image_interpolation_method + sample[item] = cv2.resize( + sample[item], + (width, height), + interpolation=interpolation_method, + ) + + if self.__resize_target: + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], + (width, height), + interpolation=cv2.INTER_NEAREST + ) + + if "mask" in sample: + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normalize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + +class NormalizeIntermediate(object): + """Normalize intermediate data by given mean and std. 
+ """ + + def __init__(self, mean, std): + + self.__int_depth_mean = mean["int_depth"] + self.__int_depth_std = std["int_depth"] + + self.__int_scales_mean = mean["int_scales"] + self.__int_scales_std = std["int_scales"] + + def __call__(self, sample): + + if "int_depth" in sample and sample["int_depth"] is not None: + sample["int_depth"] = (sample["int_depth"] - self.__int_depth_mean) / self.__int_depth_std + + if "int_scales" in sample and sample["int_scales"] is not None: + sample["int_scales"] = (sample["int_scales"] - self.__int_scales_mean) / self.__int_scales_std + + return sample + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + + for item in sample.keys(): + + if sample[item] is None: + pass + elif item == "image": + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + else: + array = sample[item].astype(np.float32) + array = np.expand_dims(array, axis=0) # add channel dim + sample[item] = np.ascontiguousarray(array) + + return sample + + +class Tensorize(object): + """Convert sample to tensor. + """ + + def __init__(self): + pass + + def __call__(self, sample): + + for item in sample.keys(): + + if sample[item] is None: + pass + else: + # before tensorizing, verify that data is clean + assert not np.any(np.isnan(sample[item])) + sample[item] = torch.Tensor(sample[item]) + + return sample + + +def get_transforms(depth_predictor, sparsifier, nsamples): + + image_mean_dict = { + "dpt_beit_large_512" : [0.5, 0.5, 0.5], + "dpt_swin2_large_384" : [0.5, 0.5, 0.5], + "dpt_large" : [0.5, 0.5, 0.5], + "dpt_hybrid" : [0.5, 0.5, 0.5], + "dpt_swin2_tiny_256" : [0.5, 0.5, 0.5], + "dpt_levit_224" : [0.5, 0.5, 0.5], + "midas_small" : [0.485, 0.456, 0.406], + } + + image_std_dict = { + "dpt_beit_large_512" : [0.5, 0.5, 0.5], + "dpt_swin2_large_384" : [0.5, 0.5, 0.5], + "dpt_large" : [0.5, 0.5, 0.5], + "dpt_hybrid" : [0.5, 0.5, 0.5], + "dpt_swin2_tiny_256" : [0.5, 0.5, 0.5], + "dpt_levit_224" : [0.5, 0.5, 0.5], + "midas_small" : [0.229, 0.224, 0.225], + } + + resize_method_dict = { + "dpt_beit_large_512" : "minimal", + "dpt_swin2_large_384" : "minimal", + "dpt_large" : "minimal", + "dpt_hybrid" : "minimal", + "dpt_swin2_tiny_256" : "minimal", + "dpt_levit_224" : "minimal", + "midas_small" : "upper_bound", + } + + resize_dict = { + "dpt_beit_large_512" : 384, + "dpt_swin2_large_384" : 384, + "dpt_large" : 384, + "dpt_hybrid" : 384, + "dpt_swin2_tiny_256" : 256, + "dpt_levit_224" : 224, + "midas_small" : 384, + } + + keep_aspect_ratio = True + if "swin2" in depth_predictor or "levit" in depth_predictor: + keep_aspect_ratio = False + + depth_model_transform_steps = [ + Resize( + width=resize_dict[depth_predictor], + height=resize_dict[depth_predictor], + resize_target=False, + keep_aspect_ratio=keep_aspect_ratio, + ensure_multiple_of=32, + resize_method=resize_method_dict[depth_predictor], + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage( + mean=image_mean_dict[depth_predictor], + std=image_std_dict[depth_predictor] + ), + PrepareForNet(), + Tensorize(), + ] + + sml_model_transform_steps = [ + Resize( + width=384, + height=384, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_method_dict["midas_small"], + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeIntermediate( + mean=normalization.VOID_INTERMEDIATE[depth_predictor][f"{sparsifier}_{nsamples}"]["mean"], + 
std=normalization.VOID_INTERMEDIATE[depth_predictor][f"{sparsifier}_{nsamples}"]["std"], + ), + PrepareForNet(), + Tensorize(), + ] + + return { + "depth_model" : transforms.Compose(depth_model_transform_steps), + "sml_model" : transforms.Compose(sml_model_transform_steps), + } diff --git a/modules/midas/utils.py b/modules/midas/utils.py new file mode 100644 index 0000000..297cdf3 --- /dev/null +++ b/modules/midas/utils.py @@ -0,0 +1,235 @@ +"""Utils for monoDepth. +""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). 
+ + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return + + +def write_png(path, array, bits=2): + """Write array to png file. + + Args: + path (str): filepath without extension + array (array): array to be saved + """ + + array_min = np.min(array) + array_max = np.max(array) + + max_val = (2**(8*bits))-1 + + if array_max - array_min > np.finfo("float").eps: + out = max_val * (array - array_min) / (array_max - array_min) + else: + print(f"zero array not being saved at {path}") + return + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return + + +def normalize_unit_range(data): + """Normalize data array to [0, 1] range. + + Args: + data (array): input array + + Returns: + array: normalized array + """ + if np.max(data) - np.min(data) > np.finfo("float").eps: + normalized = (data - np.min(data)) / (np.max(data) - np.min(data)) + else: + raise ValueError("cannot normalize array, max-min range is 0") + + return normalized \ No newline at end of file diff --git a/output/ga_depth/.placeholder b/output/ga_depth/.placeholder new file mode 100644 index 0000000..e69de29 diff --git a/output/sml_depth/.placeholder b/output/sml_depth/.placeholder new file mode 100644 index 0000000..e69de29 diff --git a/pipeline.py b/pipeline.py new file mode 100644 index 0000000..3acf0af --- /dev/null +++ b/pipeline.py @@ -0,0 +1,141 @@ +import torch +import numpy as np + +from modules.midas.midas_net_custom import MidasNet_small_videpth +from modules.estimator import LeastSquaresEstimator +from modules.interpolator import Interpolator2D + +import modules.midas.transforms as transforms +import modules.midas.utils as utils + +class VIDepth(object): + def __init__(self, depth_predictor, nsamples, sml_model_path, + min_pred, max_pred, min_depth, max_depth, device): + + # get transforms + model_transforms = transforms.get_transforms(depth_predictor, "void", str(nsamples)) + self.depth_model_transform = model_transforms["depth_model"] + self.ScaleMapLearner_transform = model_transforms["sml_model"] + + # define depth model + if depth_predictor == "dpt_beit_large_512": + self.DepthModel = torch.hub.load("intel-isl/MiDaS", "DPT_BEiT_L_512") + elif depth_predictor == "dpt_swin2_large_384": + self.DepthModel = torch.hub.load("intel-isl/MiDaS", "DPT_SwinV2_L_384") + elif depth_predictor == "dpt_large": + self.DepthModel = torch.hub.load("intel-isl/MiDaS", "DPT_Large") + elif depth_predictor == "dpt_hybrid": + self.DepthModel = torch.hub.load("intel-isl/MiDaS", "DPT_Hybrid") + elif depth_predictor == "dpt_swin2_tiny_256": + self.DepthModel = torch.hub.load("intel-isl/MiDaS", 
"DPT_SwinV2_T_256") + elif depth_predictor == "dpt_levit_224": + self.DepthModel = torch.hub.load("intel-isl/MiDaS", "DPT_LeViT_224") + elif depth_predictor == "midas_small": + self.DepthModel = torch.hub.load("intel-isl/MiDaS", "MiDaS_small") + else: + self.DepthModel = None + + # define SML model + self.ScaleMapLearner = MidasNet_small_videpth( + path=sml_model_path, + min_pred=min_pred, + max_pred=max_pred, + ) + + # depth prediction ranges + self.min_pred, self.max_pred = min_pred, max_pred + + # depth evaluation ranges + self.min_depth, self.max_depth = min_depth, max_depth + + # eval mode + self.DepthModel.eval() + self.DepthModel.to(device) + + # eval mode + self.ScaleMapLearner.eval() + self.ScaleMapLearner.to(device) + + + def run(self, input_image, input_sparse_depth, validity_map, device): + + input_height, input_width = np.shape(input_image)[0], np.shape(input_image)[1] + + sample = {"image" : input_image} + sample = self.depth_model_transform(sample) + im = sample["image"].to(device) + + input_sparse_depth_valid = (input_sparse_depth < self.max_depth) * (input_sparse_depth > self.min_depth) + if validity_map is not None: + input_sparse_depth_valid *= validity_map.astype(np.bool) + + input_sparse_depth_valid = input_sparse_depth_valid.astype(bool) + input_sparse_depth[~input_sparse_depth_valid] = np.inf # set invalid depth + input_sparse_depth = 1.0 / input_sparse_depth + + # run depth model + with torch.no_grad(): + depth_pred = self.DepthModel.forward(im.unsqueeze(0)) + depth_pred = ( + torch.nn.functional.interpolate( + depth_pred.unsqueeze(1), + size=(input_height, input_width), + mode="bicubic", + align_corners=False, + ) + .squeeze() + .cpu() + .numpy() + ) + + # global scale and shift alignment + GlobalAlignment = LeastSquaresEstimator( + estimate=depth_pred, + target=input_sparse_depth, + valid=input_sparse_depth_valid + ) + GlobalAlignment.compute_scale_and_shift() + GlobalAlignment.apply_scale_and_shift() + GlobalAlignment.clamp_min_max(clamp_min=self.min_pred, clamp_max=self.max_pred) + int_depth = GlobalAlignment.output.astype(np.float32) + + # interpolation of scale map + assert (np.sum(input_sparse_depth_valid) >= 3), "not enough valid sparse points" + ScaleMapInterpolator = Interpolator2D( + pred_inv = int_depth, + sparse_depth_inv = input_sparse_depth, + valid = input_sparse_depth_valid, + ) + ScaleMapInterpolator.generate_interpolated_scale_map( + interpolate_method='linear', + fill_corners=False + ) + int_scales = ScaleMapInterpolator.interpolated_scale_map.astype(np.float32) + int_scales = utils.normalize_unit_range(int_scales) + + sample = {"image" : input_image, "int_depth" : int_depth, "int_scales" : int_scales, "int_depth_no_tf" : int_depth} + sample = self.ScaleMapLearner_transform(sample) + x = torch.cat([sample["int_depth"], sample["int_scales"]], 0) + x = x.to(device) + d = sample["int_depth_no_tf"].to(device) + + # run SML model + with torch.no_grad(): + sml_pred, sml_scales = self.ScaleMapLearner.forward(x.unsqueeze(0), d.unsqueeze(0)) + sml_pred = ( + torch.nn.functional.interpolate( + sml_pred, + size=(input_height, input_width), + mode="bicubic", + align_corners=False, + ) + .squeeze() + .cpu() + .numpy() + ) + + output = { + "ga_depth" : int_depth, + "sml_depth" : sml_pred, + } + return output \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000..dc3bc47 --- /dev/null +++ b/run.py @@ -0,0 +1,125 @@ +import os +import argparse +import glob + +import torch +import numpy as np + +from PIL import Image + 
+import modules.midas.utils as utils + +import pipeline + + +def load_input_image(input_image_fp): + return utils.read_image(input_image_fp) + + +def load_sparse_depth(input_sparse_depth_fp): + input_sparse_depth = np.array(Image.open(input_sparse_depth_fp), dtype=np.float32) / 256.0 + input_sparse_depth[input_sparse_depth <= 0] = 0.0 + return input_sparse_depth + + +def run(depth_predictor, nsamples, sml_model_path, + min_pred, max_pred, min_depth, max_depth, + input_path, output_path, save_output): + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print("device: %s" % device) + + # instantiate method + method = pipeline.VIDepth( + depth_predictor, nsamples, sml_model_path, + min_pred, max_pred, min_depth, max_depth, device + ) + + # get inputs + img_names = glob.glob(os.path.join(input_path, "image", "*")) + num_images = len(img_names) + + # create output folders + if save_output: + os.makedirs(os.path.join(output_path, 'ga_depth'), exist_ok=True) + os.makedirs(os.path.join(output_path, 'sml_depth'), exist_ok=True) + + for ind, input_image_fp in enumerate(img_names): + if os.path.isdir(input_image_fp): + continue + + print(" processing {} ({}/{})".format(input_image_fp, ind + 1, num_images)) + + input_image = load_input_image(input_image_fp) + + input_sparse_depth_fp = input_image_fp.replace("image", "sparse_depth") + input_sparse_depth = load_sparse_depth(input_sparse_depth_fp) + + # values in the [min_depth, max_depth] range are considered valid; + # an additional validity map may be specified + validity_map = None + + # run method + output = method.run(input_image, input_sparse_depth, validity_map, device) + + if save_output: + basename = os.path.splitext(os.path.basename(input_image_fp))[0] + + # saving depth map after global alignment + utils.write_depth( + os.path.join(output_path, 'ga_depth', basename), + output["ga_depth"], bits=2 + ) + + # saving depth map after local alignment with SML + utils.write_depth( + os.path.join(output_path, 'sml_depth', basename), + output["sml_depth"], bits=2 + ) + +if __name__=="__main__": + + parser = argparse.ArgumentParser() + + # model parameters + parser.add_argument('-dp', '--depth-predictor', type=str, default='dpt_hybrid', + help='Name of depth predictor to use in pipeline.') + parser.add_argument('-ns', '--nsamples', type=int, default=150, + help='Number of sparse metric depth samples available.') + parser.add_argument('-sm', '--sml-model-path', type=str, default='', + help='Path to trained SML model weights.') + + # depth parameters + parser.add_argument('--min-pred', type=float, default=0.1, + help='Min bound for predicted depth values.') + parser.add_argument('--max-pred', type=float, default=8.0, + help='Max bound for predicted depth values.') + parser.add_argument('--min-depth', type=float, default=0.2, + help='Min valid depth when evaluating.') + parser.add_argument('--max-depth', type=float, default=5.0, + help='Max valid depth when evaluating.') + + # I/O paths + parser.add_argument('-i', '--input-path', type=str, default='./input', + help='Path to inputs.') + parser.add_argument('-o', '--output-path', type=str, default='./output', + help='Path to outputs.') + parser.add_argument('--save-output', dest='save_output', action='store_true', + help='Save output depth map.') + parser.set_defaults(save_output=False) + + args = parser.parse_args() + print(args) + + run( + args.depth_predictor, + args.nsamples, + args.sml_model_path, + args.min_pred, + args.max_pred, + args.min_depth, + args.max_depth, + 
args.input_path, + args.output_path, + args.save_output + ) \ No newline at end of file diff --git a/weights/.placeholder b/weights/.placeholder new file mode 100644 index 0000000..e69de29