Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2DGS #208

Merged
merged 118 commits into from
Sep 12, 2024
Merged

2DGS #208

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
118 commits
Select commit Hold shift + click to select a range
086dfe0
cuda legacy with 2DGS
FantasticOven2 Jun 8, 2024
906a955
clean up changes
FantasticOven2 Jun 9, 2024
b288462
add torch implementation
FantasticOven2 Jun 9, 2024
436b424
adding _wrapper for cuda legacy
FantasticOven2 Jun 9, 2024
9a549d4
working on forward (projection)
FantasticOven2 Jun 9, 2024
5764be5
working on forward (rasterization)
FantasticOven2 Jun 10, 2024
cd7786f
rasterize forward continue
FantasticOven2 Jun 10, 2024
f408b13
starting backward (rasterization)
FantasticOven2 Jun 11, 2024
8580e5e
stubs for 2dgs testing
Jun 11, 2024
e596d75
Merge branch 'main' of https://github.com/FantasticOven2/gsplat
FantasticOven2 Jun 12, 2024
fecabc5
ctn.rasterize (backward)
FantasticOven2 Jun 12, 2024
151ee3d
start projection backward
FantasticOven2 Jun 13, 2024
8cb5154
setup simple 2D experiments and debug forward
FantasticOven2 Jun 14, 2024
9a92cca
ctn. forward debugging
FantasticOven2 Jun 15, 2024
d6ead12
ctn. forward debugging
FantasticOven2 Jun 15, 2024
449e1a4
2d example works
FantasticOven2 Jun 15, 2024
af66b5e
add slant view
FantasticOven2 Jun 15, 2024
4c3717c
variable num points for 2D example
FantasticOven2 Jun 15, 2024
f15625b
working on bwd
FantasticOven2 Jun 17, 2024
39ca680
working on backward
FantasticOven2 Jun 17, 2024
966ec1f
working on backward
FantasticOven2 Jun 17, 2024
9530c89
working on bwd
FantasticOven2 Jun 17, 2024
79298a4
continue adding tests and torch impl for 2dgs
Jun 18, 2024
6a0c8c9
pipeline running nning
FantasticOven2 Jun 18, 2024
d7024ec
debugging gradient
FantasticOven2 Jun 22, 2024
3b19b70
image fitting working
FantasticOven2 Jun 23, 2024
f484aa0
3D works
FantasticOven2 Jun 24, 2024
3a00b5f
setting up numerical tests
FantasticOven2 Jun 27, 2024
37a9f3e
add projection numerical test
FantasticOven2 Jun 27, 2024
f2248b6
adding torch impl and tests for rasterization
Jun 28, 2024
0473dc9
updating 2dgs numerical tests -- rasterize_indices_in_range_2dgs is b…
Jun 28, 2024
84ac77a
correct rasterize_to_indices_in_range_2dgs
FantasticOven2 Jun 28, 2024
3909121
working on 2dgs rasterization numerical tests
FantasticOven2 Jun 28, 2024
7fabcf4
resolving rasterization numerical test differences
FantasticOven2 Jun 28, 2024
5ab8391
linter
Jun 29, 2024
406cd6f
migrating main PR 240 for background
Jun 29, 2024
b0b5da1
fixed bug in rasterize_to_indices_2dgs
Jun 29, 2024
0ff18e2
still failing forward but diffs are much smaller ~2e-5
Jun 29, 2024
1d245be
adding gitignore and removing committed images
Jun 29, 2024
39b6e1a
adding backward tests, currently all failing, will look later
Jun 29, 2024
1b25093
fixed viewmat, pass the correct viewmat, not its transpose
FantasticOven2 Jun 29, 2024
fd36aa9
added densification
FantasticOven2 Jun 30, 2024
4c11611
start normal rendering and potential deadlock
FantasticOven2 Jul 1, 2024
b37f171
normal dual visible
FantasticOven2 Jul 2, 2024
780936e
depth to normal
FantasticOven2 Jul 2, 2024
3cd6690
start normal loss gradient
FantasticOven2 Jul 3, 2024
4d68c74
normal loss gradient
FantasticOven2 Jul 4, 2024
37106b8
ctn. normal consistency gradient
FantasticOven2 Jul 4, 2024
9d47e82
normal gradient works, need to debug the output normal
FantasticOven2 Jul 5, 2024
30108ec
normal consistency done
FantasticOven2 Jul 5, 2024
b55ea01
adding depth distortion loss
FantasticOven2 Jul 7, 2024
76fdf71
added L1 distortion
FantasticOven2 Jul 7, 2024
ca9f0e1
fixed normal bug
FantasticOven2 Jul 9, 2024
7101d5f
working on numerical tests, now only v_ray_Ms didn't pass
FantasticOven2 Jul 11, 2024
164c1bc
find the potential deadlock bug
FantasticOven2 Jul 15, 2024
e4646dd
passing basic numerical tests
FantasticOven2 Jul 16, 2024
0bb01b7
working on halting problem
FantasticOven2 Jul 21, 2024
6574f15
fixed hanging problem due to thread sync
FantasticOven2 Jul 21, 2024
6e7b302
code optimization
FantasticOven2 Jul 23, 2024
bb508c0
add normal tests
FantasticOven2 Jul 24, 2024
7afd909
save changes
FantasticOven2 Jul 29, 2024
56d0a81
improve densification gradient, now working on migration to main brac…
FantasticOven2 Jul 30, 2024
6c1dee0
save changes
FantasticOven2 Aug 4, 2024
8bcf6ab
update to main version
FantasticOven2 Aug 4, 2024
eebe72b
clean tests
FantasticOven2 Aug 4, 2024
bc4fdfe
adding packed projection
FantasticOven2 Aug 5, 2024
091fc71
adding distortion
FantasticOven2 Aug 7, 2024
0a62fe8
adding tests for packed version, forward passed, backward bug
FantasticOven2 Aug 7, 2024
7e86657
cleanup
FantasticOven2 Aug 7, 2024
65ce16d
remove submodule
FantasticOven2 Aug 7, 2024
c79a32e
all tests passed
FantasticOven2 Aug 7, 2024
c31a4dc
added packed tests
FantasticOven2 Aug 8, 2024
2a8f0d6
sync with main
FantasticOven2 Aug 9, 2024
7784955
update readme
FantasticOven2 Aug 9, 2024
2b350c9
update default reg params
FantasticOven2 Aug 9, 2024
8345c5d
Merge branch 'main' into 2dgs
Aug 12, 2024
a006f94
remove mcmc script
Aug 12, 2024
c82ce96
minor cleanup
Aug 12, 2024
353fc27
black format
Aug 12, 2024
6e2cc36
fixed bug in packed version, working on cleanup
FantasticOven2 Aug 15, 2024
fd7c866
adding median depth
FantasticOven2 Aug 16, 2024
e0f2879
support median depth
FantasticOven2 Aug 17, 2024
c2420a8
backup
FantasticOven2 Aug 20, 2024
666705b
cleanup
FantasticOven2 Aug 21, 2024
d57b75f
cleaned up most of the code
FantasticOven2 Aug 21, 2024
a57c525
more cleanup
FantasticOven2 Aug 22, 2024
6dfd704
add docstring
FantasticOven2 Aug 22, 2024
48abf70
adding argument for densification gradient variable
FantasticOven2 Aug 23, 2024
2e8da2e
done cleanup
FantasticOven2 Aug 23, 2024
c538b51
update return type
FantasticOven2 Aug 28, 2024
7daa1ef
additional cleanup
FantasticOven2 Aug 28, 2024
d34baa2
fixed floaters in normal rendering
FantasticOven2 Aug 29, 2024
2f9eb75
adding docs
FantasticOven2 Sep 5, 2024
5566aa5
adding docs
FantasticOven2 Sep 5, 2024
bade5bf
generate teaser vid
FantasticOven2 Sep 5, 2024
afe0308
all done
FantasticOven2 Sep 6, 2024
1a91b03
update to align with main
FantasticOven2 Sep 6, 2024
7bb5af9
test passes
Sep 8, 2024
b1697c2
format
Sep 8, 2024
408f16e
cleanup torch_impl
Sep 8, 2024
6987c00
Merge branch 'main' into 2dgs
Sep 8, 2024
217fdc8
cleanup utils and rasterization fn
Sep 8, 2024
896cc45
black format
Sep 8, 2024
4630230
script cleanup
Sep 8, 2024
61ce019
image fitting cleanup
Sep 8, 2024
c5edf16
image fitting cleanup
Sep 8, 2024
4176666
add matplotlib to requirements
Sep 8, 2024
935cb1d
cleanup
FantasticOven2 Sep 8, 2024
cf6b32f
resolve ruilong's note
FantasticOven2 Sep 9, 2024
c5d1449
rename ray_M to ray_transform
FantasticOven2 Sep 9, 2024
75de227
update doc
FantasticOven2 Sep 9, 2024
8f39645
update doc
FantasticOven2 Sep 9, 2024
14a097a
adding docstring
FantasticOven2 Sep 9, 2024
9d80b19
remove renders
FantasticOven2 Sep 11, 2024
24add4d
Merge branch 'main' into main
liruilong940607 Sep 11, 2024
3487c74
black formatting
FantasticOven2 Sep 12, 2024
fd219f8
black format 22.3
FantasticOven2 Sep 12, 2024
8c4efba
remove renders folder in examples
FantasticOven2 Sep 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion docs/source/apis/rasterization.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
Rasterization
===================================

3DGS
------

.. currentmodule:: gsplat

Given a set of 3D gaussians parametrized by means :math:`\mu \in \mathbb{R}^3`, covariances
Expand Down Expand Up @@ -38,4 +41,36 @@ projection equation:
Where :math:`[W | t]` is the world-to-camera transformation matrix, and :math:`f_{x}, f_{y}`
are the focal lengths of the camera.

.. autofunction:: rasterization
.. autofunction:: rasterization

2DGS
------

Given a set of 2D gaussians parametrized by means :math:`\mu \in \mathbb{R}^3`, two principal tangent vectors
embedded as the first two columns of a rotation matrix :math:`R \in \mathbb{R}^{3\times3}`, and a scale matrix :math:`S \in R^{3\times3}`
representing the scaling along the two principal tangential directions, we first transforms pixels into splats' local tangent frame
by :math:`(WH)^{-1} \in \mathbb{R}^{4\times4}` and compute weights via ray-splat intersection. Then we follow the sort and rendering similar to 3DGS.

Note that H is the transformation from splat's local tangent plane :math:`\{u, v\}` into world space

.. math::

H = \begin{bmatrix}
RS & \mu \\
0 & 1
\end{bmatrix}

and :math:`W \in \mathbb{R}^{4\times4}` is the transformation matrix from world space to image space.


Splatting is done via ray-splat plane intersection. Each pixel is considered as a x-plane :math:`h_{x}=(-1, 0, 0, x)^{T}`
and a y-plane :math:`h_{y}=(0, -1, 0, y)^{T}`, and the intersection between a splat and the pixel :math:`p=(x, y)` is defined
as the intersection bwtween x-plane, y-plane, and the splat's tangent plane. We first transform :math:`h_{x}` to :math:`h_{u}` and :math:`h_{y}`
to :math:`h_{v}` in splat's tangent frame via the inverse transformation :math:`(WH)^{-1}`. As the intersection point should fall on :math:`h_{u}` and :math:`h_{v}`, we have an efficient
solution:

.. math::
u(p) = \frac{h^{2}_{u}h^{4}_{v}-h^{4}_{u}h^{2}_{v}}{h^{1}_{u}h^{2}_{v}-h^{2}_{u}h^{1}_{v}},
v(p) = \frac{h^{4}_{u}h^{1}_{v}-h^{1}_{u}h^{4}_{v}}{h^{1}_{u}h^{2}_{v}-h^{2}_{u}h^{1}_{v}}

.. autofunction:: rasterization_2dgs
17 changes: 17 additions & 0 deletions docs/source/apis/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Utils

Below are the basic functions that supports the rasterization.

3DGS
-----

.. currentmodule:: gsplat

.. autofunction:: spherical_harmonics
Expand All @@ -27,3 +30,17 @@ Below are the basic functions that supports the rasterization.
.. autofunction:: accumulate

.. autofunction:: rasterization_inria_wrapper

2DGS
-----
.. currentmodule:: gsplat

.. autofunction:: fully_fused_projection_2dgs

.. autofunction:: rasterize_to_pixels_2dgs

.. autofunction:: rasterize_to_indices_in_range_2dgs

.. autofunction:: accumulate_2dgs

.. autofunction:: rasterization_2dgs_inria_wrapper
100 changes: 100 additions & 0 deletions docs/source/tests/eval.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
Evaluation
===================================

3DGS
----------------------------------------------

.. table:: Performance on `Mip-NeRF 360 Captures <https://jonbarron.info/mipnerf360/>`_ (Averaged Over 7 Scenes)

+---------------------+-------+-------+-------+------------------+------------+
Expand Down Expand Up @@ -140,3 +143,100 @@ The evaluation of `inria-X` can be
reproduced with our forked wersion of the official implementation at
`here <https://github.com/liruilong940607/gaussian-splatting/tree/benchmark>`_,
with the command :code:`python full_eval_m360.py` (commit 36546ce).

2DGS
----------------------------------------------

No Regularization
----------------------------------------------

.. table:: Performance on `Mip-NeRF 360 Captures <https://jonbarron.info/mipnerf360/>`_ (Averaged Over 7 Scenes)

+---------------------+-------+-------+-------+------------------+------------+
| | PSNR | SSIM | LPIPS | Train Mem | Train Time |
+=====================+=======+=======+=======+==================+============+
| inria-30k | 28.73 | 0.860 | 0.148 | 3.73 GB | 22m16s |
+---------------------+-------+-------+-------+------------------+------------+
| gsplat-30k | 28.76 | 0.867 | 0.145 | **3.70 GB** | **15m44s** |
+---------------------+-------+-------+-------+------------------+------------+

With Normal Consistency and Distortion Regularization
------------------------------------------------------

+---------------------+-------+-------+-------+------------------+------------+
| | PSNR | SSIM | LPIPS | Train Mem | Train Time |
+=====================+=======+=======+=======+==================+============+
| inria-30k | 28.05 | 0.848 | 0.186 | 3.76 GB | 22m06s |
+---------------------+-------+-------+-------+------------------+------------+
| gsplat-30k | 27.80 | 0.842 | 0.169 | **3.61 GB** | **16m44s** |
+---------------------+-------+-------+-------+------------------+------------+

Runtime and GPU Memory
----------------------------------------------

+-----------------+---------+--------+---------+--------+---------+--------+--------+
| Train Mem (GB) | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
+=================+=========+========+=========+========+=========+========+========+
| inria-30k |**6.74** | 2.27 | 2.06 | 4.79 | 2.25 | 2.40 |**5.58**|
+-----------------+---------+--------+---------+--------+---------+--------+--------+
| gsplat-30k | 6.89 |**2.19**| **1.93**|**4.48**| **2.14**|**2.30**| 6.00 |
+-----------------+---------+--------+---------+--------+---------+--------+--------+

+-----------------+---------+--------+---------+--------+---------+--------+--------+
| Train Time (s) | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
+=================+=========+========+=========+========+=========+========+========+
| inria-30k | 1463 | 1237 | 1318 | 1298 | 1422 | 1314 | 1252 |
+-----------------+---------+--------+---------+--------+---------+--------+--------+
| gsplat-30k |**1231** |**788** | **803**| **985**| **828**| **789**|**1057**|
+-----------------+---------+--------+---------+--------+---------+--------+--------+


Reproduced Metrics
----------------------------------------------

+------------+---------+--------+---------+--------+---------+-------+-------+
| PSNR | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
+============+=========+========+=========+========+=========+=======+=======+
| inria-30k | 24.92 | 31.87 | 28.78 | 26.88 | 31.08 | 31.21 | 26.36 |
+------------+---------+--------+---------+--------+---------+-------+-------+
| gsplat-30k | 24.97 | 31.94 | 28.76 | 26.95 | 31.08 | 31.27 | 26.37 |
+------------+---------+--------+---------+--------+---------+-------+-------+

+------------+---------+--------+---------+--------+---------+-------+-------+
| SSIM | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
+============+=========+========+=========+========+=========+=======+=======+
| inria-30k | 0.741 | 0.937 | 0.899 | 0.847 | 0.921 | 0.914 | 0.760 |
+------------+---------+--------+---------+--------+---------+-------+-------+
| gsplat-30k | 0.764 | 0.937 | 0.899 | 0.849 | 0.921 | 0.915 | 0.761 |
+------------+---------+--------+---------+--------+---------+-------+-------+

+------------+---------+--------+---------+--------+---------+-------+-------+
| LPIPS | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
+============+=========+========+=========+========+=========+=======+=======+
| inria-30k | 0.199 | 0.136 | 0.164 | 0.093 | 0.101 | 0.172 | 0.168 |
+------------+---------+--------+---------+--------+---------+-------+-------+
| gsplat-30k | 0.189 | 0.134 | 0.162 | 0.091 | 0.101 | 0.169 | 0.166 |
+------------+---------+--------+---------+--------+---------+-------+-------+

+-----------------+---------+--------+---------+--------+---------+-------+-------+
| Number of GSs | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
+=================+=========+========+=========+========+=========+=======+=======+
| inria-30k | 3.97M | 0.91M | 0.72M | 2.79M | 0.85M | 1.01M | 3.27M |
+-----------------+---------+--------+---------+--------+---------+-------+-------+
| gsplat-30k | 3.88M | 0.92M | 0.73M | 2.49M | 0.87M | 1.03M | 3.40M |
+-----------------+---------+--------+---------+--------+---------+-------+-------+

Note: Evaulations for 2DGS are conducted on a NVIDIA RTX 4090 GPU. The LPIPS metric is evaluated
using :code:`from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity`, which
is different from what's reported in the original paper that uses
:code:`from lpipsPyTorch import lpips`.

The evaluation of `gsplat-X` can be reproduced with the command
:code:`cd examples; bash benchmarks/basic_2dgs.sh`
within the gsplat repo (commit 48abf70).

The evaluation of `inria-X` can be
reproduced with our forked wersion of the official implementation at
`here <https://github.com/hbb1/diff-surfel-rasterization>`_;
you need to change the :code:`--model_type 2dgs` to :code:`--model_type 2dgs-inria` in
:code:`benchmars/basic_2dgs` and run command :code:`cd examples; bash benchmarks/basic_2dgs.sh` (commit 28c928a).
52 changes: 52 additions & 0 deletions examples/benchmarks/basic_2dgs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
SCENE_DIR="data/360_v2"
RESULT_DIR="results/benchmark_2dgs"
SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers

for SCENE in $SCENE_LIST;
do
if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then
DATA_FACTOR=2
else
DATA_FACTOR=4
fi

echo "Running $SCENE"

# train without eval
CUDA_VISIBLE_DEVICES=0 python simple_trainer_2dgs.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
--model_type 2dgs \
--data_dir data/360_v2/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE/

# run eval and render
for CKPT in $RESULT_DIR/$SCENE/ckpts/*;
do
CUDA_VISIBLE_DEVICES=0 python simple_trainer_2dgs.py --disable_viewer --data_factor $DATA_FACTOR \
--model_type 2dgs \
--data_dir data/360_v2/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE/ \
--ckpt $CKPT
done
done


for SCENE in $SCENE_LIST;
do
echo "=== Eval Stats ==="

for STATS in $RESULT_DIR/$SCENE/stats/val*.json;
do
echo $STATS
cat $STATS;
echo
done

echo "=== Train Stats ==="

for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json;
do
echo $STATS
cat $STATS;
echo
done
done
20 changes: 15 additions & 5 deletions examples/image_fitting.py
liruilong940607 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import os
import time
from pathlib import Path
from typing import Optional
from typing import Literal, Optional

import numpy as np
import torch
import tyro
from PIL import Image
from torch import Tensor, optim

from gsplat import rasterization
from gsplat import rasterization, rasterization_2dgs


class SimpleTrainer:
Expand Down Expand Up @@ -79,6 +79,7 @@ def train(
iterations: int = 1000,
lr: float = 0.01,
save_imgs: bool = False,
model_type: Literal["3dgs", "2dgs"] = "3dgs",
liruilong940607 marked this conversation as resolved.
Show resolved Hide resolved
):
optimizer = optim.Adam(
[self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
Expand All @@ -94,9 +95,16 @@ def train(
],
device=self.device,
)

if model_type == "3dgs":
rasterize_fnc = rasterization
elif model_type == "2dgs":
rasterize_fnc = rasterization_2dgs

for iter in range(iterations):
start = time.time()
renders, _, _ = rasterization(

renders = rasterize_fnc(
self.means,
self.quats / self.quats.norm(dim=-1, keepdim=True),
self.scales,
Expand All @@ -107,7 +115,7 @@ def train(
self.W,
self.H,
packed=False,
)
)[0]
out_img = renders[0]
torch.cuda.synchronize()
times[0] += time.time() - start
Expand All @@ -125,7 +133,7 @@ def train(
if save_imgs:
# save them as a gif with PIL
frames = [Image.fromarray(frame) for frame in frames]
out_dir = os.path.join(os.getcwd(), "renders")
out_dir = os.path.join(os.getcwd(), "results")
os.makedirs(out_dir, exist_ok=True)
frames[0].save(
f"{out_dir}/training.gif",
Expand Down Expand Up @@ -158,6 +166,7 @@ def main(
img_path: Optional[Path] = None,
iterations: int = 1000,
lr: float = 0.01,
model_type: Literal["3dgs", "2dgs"] = "3dgs",
) -> None:
if img_path:
gt_image = image_path_to_tensor(img_path)
Expand All @@ -172,6 +181,7 @@ def main(
iterations=iterations,
lr=lr,
save_imgs=save_imgs,
model_type=model_type,
)


Expand Down
1 change: 1 addition & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ tyro>=0.8.8
Pillow
tensorboard
pyyaml
matplotlib
git+https://github.com/rahul-goel/fused-ssim@84422e0da94c516220eb3acedb907e68809e9e01
2 changes: 1 addition & 1 deletion examples/simple_trainer.py
FantasticOven2 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import viser
import yaml
from datasets.colmap import Dataset, Parser
from datasets.traj import generate_interpolated_path, generate_ellipse_path_z
from datasets.traj import generate_ellipse_path_z, generate_interpolated_path
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
Expand Down
Loading
Loading