Skip to content

Commit

Permalink
Consolidate simple_trainer.py and simple_trainer_mcmc.py (#325)
Browse files Browse the repository at this point in the history
* Consolidate simple_trainer.py and simple_trainer_mcmc.py

* ruff

* black

* black

* reset glm

* Revert docstring changes, fix docs Makefile

* simple_trainer.py -> simple_trainer.py default

* minor fixes

* format

* mcmc script

---------

Co-authored-by: Ruilong Li <[email protected]>
  • Loading branch information
brentyi and Ruilong Li authored Aug 10, 2024
1 parent 8bd343f commit fc1a3ca
Show file tree
Hide file tree
Showing 22 changed files with 188 additions and 872 deletions.
2 changes: 1 addition & 1 deletion EXPLORATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
| `--absgrad --grow_grad2d 2e-4` | 8m30s | 0.018s/im | 2.21 GB | 0.6251 | 20.68 | 0.587 | 0.89M |
| `--absgrad --grow_grad2d 2e-4` (30k) | -- | 0.030s/im | 5.25 GB | 0.7442 | 24.12 | 0.291 | 2.62M |

Note: default args means running `CUDA_VISIBLE_DEVICES=0 python simple_trainer.py --data_dir <DATA_DIR>` with:
Note: default args means running `CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --data_dir <DATA_DIR>` with:

- Garden ([Source](https://jonbarron.info/mipnerf360/)): `--result_dir results/garden`
- U1 (a.k.a University 1 from [Source](https://localrf.github.io/)): `--result_dir results/u1 --data_factor 1 --grow_scale3d 0.001`
Expand Down
4 changes: 2 additions & 2 deletions docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
SOURCEDIR = source
BUILDDIR = _build

# Put it first so that "make" without argument is like "make help".
Expand All @@ -17,4 +17,4 @@ help:
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
4 changes: 2 additions & 2 deletions docs/source/examples/colmap.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Fit a COLMAP Capture

.. currentmodule:: gsplat

The :code:`examples/simple_trainer.py` script allows you train a
The :code:`examples/simple_trainer.py default` script allows you train a
`3D Gaussian Splatting <https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/>`_
model for novel view synthesis, on a COLMAP processed capture. This script follows the
exact same logic with the `official implementation
Expand All @@ -15,7 +15,7 @@ Simply run the script under `examples/`:

.. code-block:: bash
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py \
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default \
--data_dir data/360_v2/garden/ --data_factor 4 \
--result_dir ./results/garden
Expand Down
2 changes: 1 addition & 1 deletion docs/source/examples/large_scale.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ The code for this example can be found under `examples/`:
.. code-block:: bash
# First train a 3DGS model
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py \
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default \
--data_dir data/360_v2/garden/ --data_factor 4 \
--result_dir ./results/garden
Expand Down
2 changes: 1 addition & 1 deletion docs/source/tests/eval.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Evaluation
| gsplat-30k (4 GPUs) | 28.91 | 0.871 | 0.135 | **2.0 GB** | **11m28s** |
+---------------------+-------+-------+-------+------------------+------------+

This repo comes with a standalone script (:code:`examples/simple_trainer.py`) that reproduces
This repo comes with a standalone script (:code:`examples/simple_trainer.py default`) that reproduces
the `Gaussian Splatting <https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/>`_ with
exactly the same performance on PSNR, SSIM, LPIPS, and converged number of Gaussians.
Powered by `gsplat`'s efficient CUDA implementation, the training takes up to
Expand Down
4 changes: 2 additions & 2 deletions examples/benchmarks/basic.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ do
echo "Running $SCENE"

# train without eval
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
--data_dir data/360_v2/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE/

# run eval and render
for CKPT in $RESULT_DIR/$SCENE/ckpts/*;
do
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py --disable_viewer --data_factor $DATA_FACTOR \
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --disable_viewer --data_factor $DATA_FACTOR \
--data_dir data/360_v2/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE/ \
--ckpt $CKPT
Expand Down
2 changes: 1 addition & 1 deletion examples/benchmarks/basic_4gpus.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ do
echo "Running $SCENE"

# train and eval at the last step
CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
# 4 GPUs is effectively 4x batch size so we scale down the steps by 4x as well.
# "--packed" reduces the data transfer between GPUs, which leads to faster training.
--steps_scaler 0.25 --packed \
Expand Down
51 changes: 51 additions & 0 deletions examples/benchmarks/mcmc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
RESULT_DIR=results/benchmark_mcmc_1M
CAP_MAX=1000000

# for SCENE in bicycle bonsai counter garden kitchen room stump;
# do
# if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ]; then
# DATA_FACTOR=4
# else
# DATA_FACTOR=2
# fi

# echo "Running $SCENE"

# # train without eval
# CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
# --strategy.cap-max $CAP_MAX \
# --data_dir data/360_v2/$SCENE/ \
# --result_dir $RESULT_DIR/$SCENE/

# # run eval and render
# for CKPT in $RESULT_DIR/$SCENE/ckpts/*;
# do
# CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \
# --strategy.cap-max $CAP_MAX \
# --data_dir data/360_v2/$SCENE/ \
# --result_dir $RESULT_DIR/$SCENE/ \
# --ckpt $CKPT
# done
# done


for SCENE in bicycle bonsai counter garden kitchen room stump;
do
echo "=== Eval Stats ==="

for STATS in $RESULT_DIR/$SCENE/stats/val*.json;
do
echo $STATS
cat $STATS;
echo
done

echo "=== Train Stats ==="

for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json;
do
echo $STATS
cat $STATS;
echo
done
done
2 changes: 1 addition & 1 deletion examples/datasets/colmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
if len(imdata) == 0:
raise ValueError("No images found in COLMAP.")
if not (type_ == 0 or type_ == 1):
print(f"Warning: COLMAP Camera is not PINHOLE. Images have distortion.")
print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.")

w2c_mats = np.stack(w2c_mats, axis=0)

Expand Down
4 changes: 1 addition & 3 deletions examples/datasets/download_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
import tyro

# dataset names
dataset_names = Literal[
"mipnerf360",
]
dataset_names = Literal["mipnerf360"]

# dataset urls
urls = {"mipnerf360": "http://storage.googleapis.com/gresearch/refraw360/360_v2.zip"}
Expand Down
8 changes: 4 additions & 4 deletions examples/datasets/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ def align_principle_axes(point_cloud):


def transform_points(matrix, points):
"""Transform points using a SE(4) matrix.
"""Transform points using an SE(3) matrix.
Args:
matrix: 4x4 SE(4) matrix
matrix: 4x4 SE(3) matrix
points: Nx3 array of points
Returns:
Expand All @@ -113,10 +113,10 @@ def transform_points(matrix, points):


def transform_cameras(matrix, camtoworlds):
"""Transform cameras using a SE(4) matrix.
"""Transform cameras using an SE(3) matrix.
Args:
matrix: 4x4 SE(4) matrix
matrix: 4x4 SE(3) matrix
camtoworlds: Nx4x4 array of camera-to-world matrices
Returns:
Expand Down
1 change: 1 addition & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ opencv-python
tyro
Pillow
tensorboard
pyyaml
Loading

0 comments on commit fc1a3ca

Please sign in to comment.