diff --git a/EXPLORATION.md b/EXPLORATION.md index e8698ab75..91fd54b2f 100644 --- a/EXPLORATION.md +++ b/EXPLORATION.md @@ -26,7 +26,7 @@ | `--absgrad --grow_grad2d 2e-4` | 8m30s | 0.018s/im | 2.21 GB | 0.6251 | 20.68 | 0.587 | 0.89M | | `--absgrad --grow_grad2d 2e-4` (30k) | -- | 0.030s/im | 5.25 GB | 0.7442 | 24.12 | 0.291 | 2.62M | -Note: default args means running `CUDA_VISIBLE_DEVICES=0 python simple_trainer.py --data_dir ` with: +Note: default args means running `CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --data_dir ` with: - Garden ([Source](https://jonbarron.info/mipnerf360/)): `--result_dir results/garden` - U1 (a.k.a University 1 from [Source](https://localrf.github.io/)): `--result_dir results/u1 --data_factor 1 --grow_scale3d 0.001` diff --git a/docs/Makefile b/docs/Makefile index 41c270bb3..92dd33a1a 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -5,7 +5,7 @@ # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build -SOURCEDIR = . +SOURCEDIR = source BUILDDIR = _build # Put it first so that "make" without argument is like "make help". @@ -17,4 +17,4 @@ help: # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) \ No newline at end of file + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/examples/colmap.rst b/docs/source/examples/colmap.rst index 13eb344f1..b6ae6eb3a 100644 --- a/docs/source/examples/colmap.rst +++ b/docs/source/examples/colmap.rst @@ -3,7 +3,7 @@ Fit a COLMAP Capture .. currentmodule:: gsplat -The :code:`examples/simple_trainer.py` script allows you train a +The :code:`examples/simple_trainer.py default` script allows you to train a `3D Gaussian Splatting `_ model for novel view synthesis, on a COLMAP processed capture. This script follows the exact same logic with the `official implementation @@ -15,7 +15,7 @@ Simply run the script under `examples/`: .. code-block:: bash - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py \ + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default \ --data_dir data/360_v2/garden/ --data_factor 4 \ --result_dir ./results/garden diff --git a/docs/source/examples/large_scale.rst b/docs/source/examples/large_scale.rst index 46db0bc49..3288eb512 100644 --- a/docs/source/examples/large_scale.rst +++ b/docs/source/examples/large_scale.rst @@ -35,7 +35,7 @@ The code for this example can be found under `examples/`: .. code-block:: bash # First train a 3DGS model - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py \ + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default \ --data_dir data/360_v2/garden/ --data_factor 4 \ --result_dir ./results/garden diff --git a/docs/source/tests/eval.rst b/docs/source/tests/eval.rst index 486cb39b3..9f2a2fdc0 100644 --- a/docs/source/tests/eval.rst +++ b/docs/source/tests/eval.rst @@ -17,7 +17,7 @@ Evaluation | gsplat-30k (4 GPUs) | 28.91 | 0.871 | 0.135 | **2.0 GB** | **11m28s** | +---------------------+-------+-------+-------+------------------+------------+ -This repo comes with a standalone script (:code:`examples/simple_trainer.py`) that reproduces +This repo comes with a standalone script (:code:`examples/simple_trainer.py default`) that reproduces the `Gaussian Splatting `_ with exactly the same performance on PSNR, SSIM, LPIPS, and converged number of Gaussians.
Powered by `gsplat`'s efficient CUDA implementation, the training takes up to diff --git a/examples/benchmarks/basic.sh b/examples/benchmarks/basic.sh index e804285dc..0c72c0c07 100644 --- a/examples/benchmarks/basic.sh +++ b/examples/benchmarks/basic.sh @@ -11,14 +11,14 @@ do echo "Running $SCENE" # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ --data_dir data/360_v2/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ # run eval and render for CKPT in $RESULT_DIR/$SCENE/ckpts/*; do - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py --disable_viewer --data_factor $DATA_FACTOR \ + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --disable_viewer --data_factor $DATA_FACTOR \ --data_dir data/360_v2/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ \ --ckpt $CKPT diff --git a/examples/benchmarks/basic_4gpus.sh b/examples/benchmarks/basic_4gpus.sh index 3c3ad334e..523421609 100644 --- a/examples/benchmarks/basic_4gpus.sh +++ b/examples/benchmarks/basic_4gpus.sh @@ -11,7 +11,7 @@ do echo "Running $SCENE" # train and eval at the last step - CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ + CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ # 4 GPUs is effectively 4x batch size so we scale down the steps by 4x as well. # "--packed" reduces the data transfer between GPUs, which leads to faster training. --steps_scaler 0.25 --packed \ diff --git a/examples/benchmarks/mcmc.sh b/examples/benchmarks/mcmc.sh new file mode 100644 index 000000000..057f0f4ce --- /dev/null +++ b/examples/benchmarks/mcmc.sh @@ -0,0 +1,51 @@ +RESULT_DIR=results/benchmark_mcmc_1M +CAP_MAX=1000000 + +# for SCENE in bicycle bonsai counter garden kitchen room stump; +# do +# if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ]; then +# DATA_FACTOR=4 +# else +# DATA_FACTOR=2 +# fi + +# echo "Running $SCENE" + +# # train without eval +# CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ +# --strategy.cap-max $CAP_MAX \ +# --data_dir data/360_v2/$SCENE/ \ +# --result_dir $RESULT_DIR/$SCENE/ + +# # run eval and render +# for CKPT in $RESULT_DIR/$SCENE/ckpts/*; +# do +# CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ +# --strategy.cap-max $CAP_MAX \ +# --data_dir data/360_v2/$SCENE/ \ +# --result_dir $RESULT_DIR/$SCENE/ \ +# --ckpt $CKPT +# done +# done + + +for SCENE in bicycle bonsai counter garden kitchen room stump; +do + echo "=== Eval Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/val*.json; + do + echo $STATS + cat $STATS; + echo + done + + echo "=== Train Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json; + do + echo $STATS + cat $STATS; + echo + done +done \ No newline at end of file diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 22eacc2ac..2d34734e6 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -113,7 +113,7 @@ def __init__( if len(imdata) == 0: raise ValueError("No images found in COLMAP.") if not (type_ == 0 or type_ == 1): - print(f"Warning: COLMAP Camera is not PINHOLE. Images have distortion.") + print("Warning: COLMAP Camera is not PINHOLE. 
Images have distortion.") w2c_mats = np.stack(w2c_mats, axis=0) diff --git a/examples/datasets/download_dataset.py b/examples/datasets/download_dataset.py index 8366ae979..81b6fe559 100755 --- a/examples/datasets/download_dataset.py +++ b/examples/datasets/download_dataset.py @@ -9,9 +9,7 @@ import tyro # dataset names -dataset_names = Literal[ - "mipnerf360", -] +dataset_names = Literal["mipnerf360"] # dataset urls urls = {"mipnerf360": "http://storage.googleapis.com/gresearch/refraw360/360_v2.zip"} diff --git a/examples/datasets/normalize.py b/examples/datasets/normalize.py index bf0218f05..85b06456d 100644 --- a/examples/datasets/normalize.py +++ b/examples/datasets/normalize.py @@ -98,10 +98,10 @@ def align_principle_axes(point_cloud): def transform_points(matrix, points): - """Transform points using a SE(4) matrix. + """Transform points using an SE(3) matrix. Args: - matrix: 4x4 SE(4) matrix + matrix: 4x4 SE(3) matrix points: Nx3 array of points Returns: @@ -113,10 +113,10 @@ def transform_points(matrix, points): def transform_cameras(matrix, camtoworlds): - """Transform cameras using a SE(4) matrix. + """Transform cameras using an SE(3) matrix. Args: - matrix: 4x4 SE(4) matrix + matrix: 4x4 SE(3) matrix camtoworlds: Nx4x4 array of camera-to-world matrices Returns: diff --git a/examples/requirements.txt b/examples/requirements.txt index f5cf24dfe..46273f4ca 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -16,3 +16,4 @@ opencv-python tyro Pillow tensorboard +pyyaml \ No newline at end of file diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index a96db1214..2e2f52988 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -3,28 +3,30 @@ import os import time from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import imageio import nerfview import numpy as np +import yaml import torch import torch.nn.functional as F import tqdm import tyro import viser -from datasets.colmap import Dataset, Parser -from datasets.traj import generate_interpolated_path +from gsplat.distributed import cli +from gsplat.rendering import rasterization +from gsplat.strategy import DefaultStrategy, MCMCStrategy from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity -from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed +from typing_extensions import assert_never -from gsplat.distributed import cli -from gsplat.rendering import rasterization -from gsplat.strategy import DefaultStrategy +from datasets.colmap import Dataset, Parser +from datasets.traj import generate_interpolated_path +from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed @dataclass @@ -84,34 +86,16 @@ class Config: # Far plane clipping distance far_plane: float = 1e10 - # GSs with opacity below this value will be pruned - prune_opa: float = 0.005 - # GSs with image plane gradient above this value will be split/duplicated - grow_grad2d: float = 0.0002 - # GSs with scale below this value will be duplicated. Above will be split - grow_scale3d: float = 0.01 - # GSs with scale above this value will be pruned. 
- prune_scale3d: float = 0.1 - - # Start refining GSs after this iteration - refine_start_iter: int = 500 - # Stop refining GSs after this iteration - refine_stop_iter: int = 15_000 - # Reset opacities every this steps - reset_every: int = 3000 - # Refine GSs every this steps - refine_every: int = 100 - + # Strategy for GS densification + strategy: Union[DefaultStrategy, MCMCStrategy] = field( + default_factory=DefaultStrategy + ) # Use packed mode for rasterization, this leads to less memory usage but slightly slower. packed: bool = False # Use sparse gradients for optimization. (experimental) sparse_grad: bool = False - # Use absolute gradient for pruning. This typically requires larger --grow_grad2d, e.g., 0.0008 or 0.0006 - absgrad: bool = False # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. antialiased: bool = False - # Whether to use revised opacity heuristic from arXiv:2404.06109 (experimental) - revised_opacity: bool = False # Use random background for training to discourage transparency random_bkgd: bool = False @@ -149,10 +133,19 @@ def adjust_steps(self, factor: float): self.save_steps = [int(i * factor) for i in self.save_steps] self.max_steps = int(self.max_steps * factor) self.sh_degree_interval = int(self.sh_degree_interval * factor) - self.refine_start_iter = int(self.refine_start_iter * factor) - self.refine_stop_iter = int(self.refine_stop_iter * factor) - self.reset_every = int(self.reset_every * factor) - self.refine_every = int(self.refine_every * factor) + + strategy = self.strategy + if isinstance(strategy, DefaultStrategy): + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.reset_every = int(strategy.reset_every * factor) + strategy.refine_every = int(strategy.refine_every * factor) + elif isinstance(strategy, MCMCStrategy): + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.refine_every = int(strategy.refine_every * factor) + else: + assert_never(strategy) def create_splats_with_optimizers( @@ -299,23 +292,16 @@ def __init__( print("Model initialized. 
Number of GS:", len(self.splats["means"])) # Densification Strategy - self.strategy = DefaultStrategy( - verbose=True, - scene_scale=self.scene_scale, - prune_opa=cfg.prune_opa, - grow_grad2d=cfg.grow_grad2d, - grow_scale3d=cfg.grow_scale3d, - prune_scale3d=cfg.prune_scale3d, - # refine_scale2d_stop_iter=4000, # splatfacto behavior - refine_start_iter=cfg.refine_start_iter, - refine_stop_iter=cfg.refine_stop_iter, - reset_every=cfg.reset_every, - refine_every=cfg.refine_every, - absgrad=cfg.absgrad, - revised_opacity=cfg.revised_opacity, - ) - self.strategy.check_sanity(self.splats, self.optimizers) - self.strategy_state = self.strategy.initialize_state() + self.cfg.strategy.check_sanity(self.splats, self.optimizers) + + if isinstance(self.cfg.strategy, DefaultStrategy): + self.strategy_state = self.cfg.strategy.initialize_state( + scene_scale=self.scene_scale + ) + elif isinstance(self.cfg.strategy, MCMCStrategy): + self.strategy_state = self.cfg.strategy.initialize_state() + else: + assert_never(self.cfg.strategy) self.pose_optimizers = [] if cfg.pose_opt: @@ -339,6 +325,7 @@ def __init__( self.app_optimizers = [] if cfg.app_opt: + assert feature_dim is not None self.app_module = AppearanceOptModule( len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree ).to(self.device) @@ -415,7 +402,11 @@ def rasterize_splats( width=width, height=height, packed=self.cfg.packed, - absgrad=self.cfg.absgrad, + absgrad=( + self.cfg.strategy.absgrad + if isinstance(self.cfg.strategy, DefaultStrategy) + else False + ), sparse_grad=self.cfg.sparse_grad, rasterize_mode=rasterize_mode, distributed=self.world_size > 1, @@ -431,8 +422,8 @@ def train(self): # Dump cfg. if world_rank == 0: - with open(f"{cfg.result_dir}/cfg.json", "w") as f: - json.dump(vars(cfg), f) + with open(f"{cfg.result_dir}/cfg.yml", "w") as f: + yaml.dump(vars(cfg), f) max_steps = cfg.max_steps init_step = 0 @@ -520,7 +511,7 @@ def train(self): bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) - self.strategy.step_pre_backward( + self.cfg.strategy.step_pre_backward( params=self.splats, optimizers=self.optimizers, state=self.strategy_state, @@ -618,14 +609,26 @@ def train(self): data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" ) - self.strategy.step_post_backward( - params=self.splats, - optimizers=self.optimizers, - state=self.strategy_state, - step=step, - info=info, - packed=cfg.packed, - ) + if isinstance(self.cfg.strategy, DefaultStrategy): + self.cfg.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + packed=cfg.packed, + ) + elif isinstance(self.cfg.strategy, MCMCStrategy): + self.cfg.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + lr=schedulers[0].get_last_lr()[0], + ) + else: + assert_never(self.cfg.strategy) # Turn Gradients into Sparse Tensor before running optimizer if cfg.sparse_grad: @@ -656,7 +659,7 @@ def train(self): scheduler.step() # eval the full set - if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1: + if step in [i - 1 for i in cfg.eval_steps]: self.eval(step) self.render_traj(step) @@ -850,13 +853,43 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config): ```bash # Single GPU training - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default # Distributed training on 4 GPUs: Effectively 4x batch size so run 
4x less steps. - CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py --steps_scaler 0.25 + CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25 """ - cfg = tyro.cli(Config) + # Config objects we can choose between. + # Each is a tuple of (CLI description, config object). + configs = { + "default": ( + "Gaussian splatting training using densification heuristics from the original paper.", + Config( + strategy=DefaultStrategy(verbose=True), + ), + ), + "mcmc": ( + "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.", + Config( + init_opa=0.5, + init_scale=0.1, + strategy=MCMCStrategy(verbose=True), + ), + ), + } + + # We're going to do some advanced tyro stuff to make the CLI nicer. + # + # (1) Build a union type that lets us choose between the two config + # objects. + subcommand_type = tyro.extras.subcommand_type_from_defaults( + defaults={k: v[1] for k, v in configs.items()}, + descriptions={k: v[0] for k, v in configs.items()}, + ) + # (2) Don't let the user override the strategy type provided by the default that they choose. + subcommand_type = tyro.conf.configure(tyro.conf.AvoidSubcommands)(subcommand_type) + + cfg = tyro.cli(subcommand_type) cfg.adjust_steps(cfg.steps_scaler) cli(main, cfg, verbose=True) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py deleted file mode 100644 index 483cbd66b..000000000 --- a/examples/simple_trainer_mcmc.py +++ /dev/null @@ -1,763 +0,0 @@ -import json -import math -import os -import time -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple - -import imageio -import nerfview -import numpy as np -import torch -import torch.nn.functional as F -import tqdm -import tyro -import viser -from datasets.colmap import Dataset, Parser -from datasets.traj import generate_interpolated_path -from simple_trainer import create_splats_with_optimizers -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity -from utils import AppearanceOptModule, CameraOptModule, set_random_seed - -from gsplat.distributed import cli -from gsplat.rendering import rasterization -from gsplat.strategy import MCMCStrategy - - -@dataclass -class Config: - # Disable viewer - disable_viewer: bool = False - # Path to the .pt file. If provide, it will skip training and render a video - ckpt: Optional[str] = None - - # Path to the Mip-NeRF 360 dataset - data_dir: str = "data/360_v2/garden" - # Downsample factor for the dataset - data_factor: int = 4 - # Directory to save results - result_dir: str = "results/garden" - # Every N images there is a test image - test_every: int = 8 - # Random crop size for training (experimental) - patch_size: Optional[int] = None - # A global scaler that applies to the scene size related parameters - global_scale: float = 1.0 - - # Port for the viewer server - port: int = 8080 - - # Batch size for training. 
Learning rates are scaled automatically - batch_size: int = 1 - # A global factor to scale the number of training steps - steps_scaler: float = 1.0 - - # Number of training steps - max_steps: int = 30_000 - # Steps to evaluate the model - eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) - # Steps to save the model - save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) - - # Initialization strategy - init_type: str = "sfm" - # Initial number of GSs. Ignored if using sfm - init_num_pts: int = 100_000 - # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm - init_extent: float = 3.0 - # Degree of spherical harmonics - sh_degree: int = 3 - # Turn on another SH degree every this steps - sh_degree_interval: int = 1000 - # Initial opacity of GS - init_opa: float = 0.5 - # Initial scale of GS - init_scale: float = 0.1 - # Weight for SSIM loss - ssim_lambda: float = 0.2 - - # Near plane clipping distance - near_plane: float = 0.01 - # Far plane clipping distance - far_plane: float = 1e10 - - # Maximum number of GSs. - cap_max: int = 1_000_000 - # MCMC samping noise learning rate - noise_lr = 5e5 - # Opacity regularization - opacity_reg = 0.01 - # Scale regularization - scale_reg = 0.01 - - # Start refining GSs after this iteration - refine_start_iter: int = 500 - # Stop refining GSs after this iteration - refine_stop_iter: int = 25_000 - # Refine GSs every this steps - refine_every: int = 100 - - # Use packed mode for rasterization, this leads to less memory usage but slightly slower. - packed: bool = False - # Use sparse gradients for optimization. (experimental) - sparse_grad: bool = False - # Use absolute gradient for pruning. This typically requires larger --grow_grad2d, e.g., 0.0008 or 0.0006 - absgrad: bool = False - # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. - antialiased: bool = False - - # Use random background for training to discourage transparency - random_bkgd: bool = False - - # Enable camera optimization. - pose_opt: bool = False - # Learning rate for camera optimization - pose_opt_lr: float = 1e-5 - # Regularization for camera optimization as weight decay - pose_opt_reg: float = 1e-6 - # Add noise to camera extrinsics. This is only to test the camera pose optimization. - pose_noise: float = 0.0 - - # Enable appearance optimization. (experimental) - app_opt: bool = False - # Appearance embedding dimension - app_embed_dim: int = 16 - # Learning rate for appearance optimization - app_opt_lr: float = 1e-3 - # Regularization for appearance optimization as weight decay - app_opt_reg: float = 1e-6 - - # Enable depth loss. 
(experimental) - depth_loss: bool = False - # Weight for depth loss - depth_lambda: float = 1e-2 - - # Dump information to tensorboard every this steps - tb_every: int = 100 - # Save training images to tensorboard - tb_save_image: bool = False - - def adjust_steps(self, factor: float): - self.eval_steps = [int(i * factor) for i in self.eval_steps] - self.save_steps = [int(i * factor) for i in self.save_steps] - self.max_steps = int(self.max_steps * factor) - self.sh_degree_interval = int(self.sh_degree_interval * factor) - self.refine_start_iter = int(self.refine_start_iter * factor) - self.refine_stop_iter = int(self.refine_stop_iter * factor) - self.refine_every = int(self.refine_every * factor) - - -class Runner: - """Engine for training and testing.""" - - def __init__( - self, local_rank: int, world_rank, world_size: int, cfg: Config - ) -> None: - set_random_seed(42 + local_rank) - - self.cfg = cfg - self.world_rank = world_rank - self.local_rank = local_rank - self.world_size = world_size - self.device = f"cuda:{local_rank}" - - # Where to dump results. - os.makedirs(cfg.result_dir, exist_ok=True) - - # Setup output directories. - self.ckpt_dir = f"{cfg.result_dir}/ckpts" - os.makedirs(self.ckpt_dir, exist_ok=True) - self.stats_dir = f"{cfg.result_dir}/stats" - os.makedirs(self.stats_dir, exist_ok=True) - self.render_dir = f"{cfg.result_dir}/renders" - os.makedirs(self.render_dir, exist_ok=True) - - # Tensorboard - self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") - - # Load data: Training data should contain initial points and colors. - self.parser = Parser( - data_dir=cfg.data_dir, - factor=cfg.data_factor, - normalize=True, - test_every=cfg.test_every, - ) - self.trainset = Dataset( - self.parser, - split="train", - patch_size=cfg.patch_size, - load_depths=cfg.depth_loss, - ) - self.valset = Dataset(self.parser, split="val") - self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale - print("Scene scale:", self.scene_scale) - - # Model - feature_dim = 32 if cfg.app_opt else None - self.splats, self.optimizers = create_splats_with_optimizers( - self.parser, - init_type=cfg.init_type, - init_num_pts=cfg.init_num_pts, - init_extent=cfg.init_extent, - init_opacity=cfg.init_opa, - init_scale=cfg.init_scale, - scene_scale=self.scene_scale, - sh_degree=cfg.sh_degree, - sparse_grad=cfg.sparse_grad, - batch_size=cfg.batch_size, - feature_dim=feature_dim, - device=self.device, - world_rank=world_rank, - world_size=world_size, - ) - print("Model initialized. 
Number of GS:", len(self.splats["means"])) - - # Densification Strategy - self.strategy = MCMCStrategy( - verbose=True, - cap_max=cfg.cap_max, - noise_lr=cfg.noise_lr, - refine_start_iter=cfg.refine_start_iter, - refine_stop_iter=cfg.refine_stop_iter, - refine_every=cfg.refine_every, - ) - self.strategy.check_sanity(self.splats, self.optimizers) - self.strategy_state = self.strategy.initialize_state() - - self.pose_optimizers = [] - if cfg.pose_opt: - self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) - self.pose_adjust.zero_init() - self.pose_optimizers = [ - torch.optim.Adam( - self.pose_adjust.parameters(), - lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), - weight_decay=cfg.pose_opt_reg, - ) - ] - if world_size > 1: - self.pose_adjust = DDP(self.pose_adjust) - - if cfg.pose_noise > 0.0: - self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) - self.pose_perturb.random_init(cfg.pose_noise) - if world_size > 1: - self.pose_perturb = DDP(self.pose_perturb) - - self.app_optimizers = [] - if cfg.app_opt: - self.app_module = AppearanceOptModule( - len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree - ).to(self.device) - # initialize the last layer to be zero so that the initial output is zero. - torch.nn.init.zeros_(self.app_module.color_head[-1].weight) - torch.nn.init.zeros_(self.app_module.color_head[-1].bias) - self.app_optimizers = [ - torch.optim.Adam( - self.app_module.embeds.parameters(), - lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, - weight_decay=cfg.app_opt_reg, - ), - torch.optim.Adam( - self.app_module.color_head.parameters(), - lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), - ), - ] - if world_size > 1: - self.app_module = DDP(self.app_module) - - # Losses & Metrics. - self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) - self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) - self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to( - self.device - ) - - # Viewer - if not self.cfg.disable_viewer: - self.server = viser.ViserServer(port=cfg.port, verbose=False) - self.viewer = nerfview.Viewer( - server=self.server, - render_fn=self._viewer_render_fn, - mode="training", - ) - - def rasterize_splats( - self, - camtoworlds: Tensor, - Ks: Tensor, - width: int, - height: int, - **kwargs, - ) -> Tuple[Tensor, Tensor, Dict]: - means = self.splats["means"] # [N, 3] - # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] - # rasterization does normalization internally - quats = self.splats["quats"] # [N, 4] - scales = torch.exp(self.splats["scales"]) # [N, 3] - opacities = torch.sigmoid(self.splats["opacities"]) # [N,] - - image_ids = kwargs.pop("image_ids", None) - if self.cfg.app_opt: - colors = self.app_module( - features=self.splats["features"], - embed_ids=image_ids, - dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], - sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), - ) - colors = colors + self.splats["colors"] - colors = torch.sigmoid(colors) - else: - colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] - - rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" - render_colors, render_alphas, info = rasterization( - means=means, - quats=quats, - scales=scales, - opacities=opacities, - colors=colors, - viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] - Ks=Ks, # [C, 3, 3] - width=width, - height=height, - packed=self.cfg.packed, - absgrad=self.cfg.absgrad, - sparse_grad=self.cfg.sparse_grad, - 
rasterize_mode=rasterize_mode, - distributed=self.world_size > 1, - **kwargs, - ) - return render_colors, render_alphas, info - - def train(self): - cfg = self.cfg - device = self.device - world_rank = self.world_rank - world_size = self.world_size - - # Dump cfg. - if world_rank == 0: - with open(f"{cfg.result_dir}/cfg.json", "w") as f: - json.dump(vars(cfg), f) - - max_steps = cfg.max_steps - init_step = 0 - - schedulers = [ - # means has a learning rate schedule, that end at 0.01 of the initial value - torch.optim.lr_scheduler.ExponentialLR( - self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps) - ), - ] - if cfg.pose_opt: - # pose optimization has a learning rate schedule - schedulers.append( - torch.optim.lr_scheduler.ExponentialLR( - self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) - ) - ) - - trainloader = torch.utils.data.DataLoader( - self.trainset, - batch_size=cfg.batch_size, - shuffle=True, - num_workers=4, - persistent_workers=True, - pin_memory=True, - ) - trainloader_iter = iter(trainloader) - - # Training loop. - global_tic = time.time() - pbar = tqdm.tqdm(range(init_step, max_steps)) - for step in pbar: - if not cfg.disable_viewer: - while self.viewer.state.status == "paused": - time.sleep(0.01) - self.viewer.lock.acquire() - tic = time.time() - - try: - data = next(trainloader_iter) - except StopIteration: - trainloader_iter = iter(trainloader) - data = next(trainloader_iter) - - camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] - Ks = data["K"].to(device) # [1, 3, 3] - pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] - num_train_rays_per_step = ( - pixels.shape[0] * pixels.shape[1] * pixels.shape[2] - ) - image_ids = data["image_id"].to(device) - if cfg.depth_loss: - points = data["points"].to(device) # [1, M, 2] - depths_gt = data["depths"].to(device) # [1, M] - - height, width = pixels.shape[1:3] - - if cfg.pose_noise: - camtoworlds = self.pose_perturb(camtoworlds, image_ids) - - if cfg.pose_opt: - camtoworlds = self.pose_adjust(camtoworlds, image_ids) - - # sh schedule - sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree) - - # forward - renders, alphas, info = self.rasterize_splats( - camtoworlds=camtoworlds, - Ks=Ks, - width=width, - height=height, - sh_degree=sh_degree_to_use, - near_plane=cfg.near_plane, - far_plane=cfg.far_plane, - image_ids=image_ids, - render_mode="RGB+ED" if cfg.depth_loss else "RGB", - ) - if renders.shape[-1] == 4: - colors, depths = renders[..., 0:3], renders[..., 3:4] - else: - colors, depths = renders, None - - if cfg.random_bkgd: - bkgd = torch.rand(1, 3, device=device) - colors = colors + bkgd * (1.0 - alphas) - - # loss - l1loss = F.l1_loss(colors, pixels) - ssimloss = 1.0 - self.ssim( - pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2) - ) - loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda - if cfg.depth_loss: - # query depths from depth map - points = torch.stack( - [ - points[:, :, 0] / (width - 1) * 2 - 1, - points[:, :, 1] / (height - 1) * 2 - 1, - ], - dim=-1, - ) # normalize to [-1, 1] - grid = points.unsqueeze(2) # [1, M, 1, 2] - depths = F.grid_sample( - depths.permute(0, 3, 1, 2), grid, align_corners=True - ) # [1, 1, M, 1] - depths = depths.squeeze(3).squeeze(1) # [1, M] - # calculate loss in disparity space - disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) - disp_gt = 1.0 / depths_gt # [1, M] - depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale - loss += depthloss * cfg.depth_lambda - - loss = ( - loss - + 
cfg.opacity_reg - * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() - ) - loss = ( - loss - + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() - ) - - loss.backward() - - desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " - if cfg.depth_loss: - desc += f"depth loss={depthloss.item():.6f}| " - if cfg.pose_opt and cfg.pose_noise: - # monitor the pose error if we inject noise - pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) - desc += f"pose err={pose_err.item():.6f}| " - pbar.set_description(desc) - - # write images (gt and render) - # if world_rank == 0 and step % 800 == 0: - # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() - # canvas = canvas.reshape(-1, *canvas.shape[2:]) - # imageio.imwrite( - # f"{self.render_dir}/train_rank{self.world_rank}.png", - # (canvas * 255).astype(np.uint8), - # ) - - if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: - mem = torch.cuda.max_memory_allocated() / 1024**3 - self.writer.add_scalar("train/loss", loss.item(), step) - self.writer.add_scalar("train/l1loss", l1loss.item(), step) - self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) - self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step) - self.writer.add_scalar("train/mem", mem, step) - if cfg.depth_loss: - self.writer.add_scalar("train/depthloss", depthloss.item(), step) - if cfg.tb_save_image: - canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() - canvas = canvas.reshape(-1, *canvas.shape[2:]) - self.writer.add_image("train/render", canvas, step) - self.writer.flush() - - # save checkpoint before updating the model - if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: - mem = torch.cuda.max_memory_allocated() / 1024**3 - stats = { - "mem": mem, - "ellipse_time": time.time() - global_tic, - "num_GS": len(self.splats["means"]), - } - print("Step: ", step, stats) - with open( - f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", - "w", - ) as f: - json.dump(stats, f) - data = {"step": step, "splats": self.splats.state_dict()} - if cfg.pose_opt: - if world_size > 1: - data["pose_adjust"] = self.pose_adjust.module.state_dict() - else: - data["pose_adjust"] = self.pose_adjust.state_dict() - if cfg.app_opt: - if world_size > 1: - data["app_module"] = self.app_module.module.state_dict() - else: - data["app_module"] = self.app_module.state_dict() - torch.save( - data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" - ) - - self.strategy.step_post_backward( - params=self.splats, - optimizers=self.optimizers, - state=self.strategy_state, - step=step, - info=info, - lr=schedulers[0].get_last_lr()[0], - ) - - # Turn Gradients into Sparse Tensor before running optimizer - if cfg.sparse_grad: - assert cfg.packed, "Sparse gradients only work with packed mode." - gaussian_ids = info["gaussian_ids"] - for k in self.splats.keys(): - grad = self.splats[k].grad - if grad is None or grad.is_sparse: - continue - self.splats[k].grad = torch.sparse_coo_tensor( - indices=gaussian_ids[None], # [1, nnz] - values=grad[gaussian_ids], # [nnz, ...] - size=self.splats[k].size(), # [N, ...] 
- is_coalesced=len(Ks) == 1, - ) - - # optimize - for optimizer in self.optimizers.values(): - optimizer.step() - optimizer.zero_grad(set_to_none=True) - for optimizer in self.pose_optimizers: - optimizer.step() - optimizer.zero_grad(set_to_none=True) - for optimizer in self.app_optimizers: - optimizer.step() - optimizer.zero_grad(set_to_none=True) - for scheduler in schedulers: - scheduler.step() - - # eval the full set - if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1: - self.eval(step) - self.render_traj(step) - - if not cfg.disable_viewer: - self.viewer.lock.release() - num_train_steps_per_sec = 1.0 / (time.time() - tic) - num_train_rays_per_sec = ( - num_train_rays_per_step * num_train_steps_per_sec - ) - # Update the viewer state. - self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec - # Update the scene. - self.viewer.update(step, num_train_rays_per_step) - - @torch.no_grad() - def eval(self, step: int): - """Entry for evaluation.""" - print("Running evaluation...") - cfg = self.cfg - device = self.device - world_rank = self.world_rank - world_size = self.world_size - - valloader = torch.utils.data.DataLoader( - self.valset, batch_size=1, shuffle=False, num_workers=1 - ) - ellipse_time = 0 - metrics = {"psnr": [], "ssim": [], "lpips": []} - for i, data in enumerate(valloader): - camtoworlds = data["camtoworld"].to(device) - Ks = data["K"].to(device) - pixels = data["image"].to(device) / 255.0 - height, width = pixels.shape[1:3] - - torch.cuda.synchronize() - tic = time.time() - colors, _, _ = self.rasterize_splats( - camtoworlds=camtoworlds, - Ks=Ks, - width=width, - height=height, - sh_degree=cfg.sh_degree, - near_plane=cfg.near_plane, - far_plane=cfg.far_plane, - ) # [1, H, W, 3] - colors = torch.clamp(colors, 0.0, 1.0) - torch.cuda.synchronize() - ellipse_time += time.time() - tic - - if world_rank == 0: - # write images - canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() - imageio.imwrite( - f"{self.render_dir}/val_{i:04d}.png", - (canvas * 255).astype(np.uint8), - ) - - pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] - colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W] - metrics["psnr"].append(self.psnr(colors, pixels)) - metrics["ssim"].append(self.ssim(colors, pixels)) - metrics["lpips"].append(self.lpips(colors, pixels)) - - if world_rank == 0: - ellipse_time /= len(valloader) - - psnr = torch.stack(metrics["psnr"]).mean() - ssim = torch.stack(metrics["ssim"]).mean() - lpips = torch.stack(metrics["lpips"]).mean() - print( - f"PSNR: {psnr.item():.3f}, SSIM: {ssim.item():.4f}, LPIPS: {lpips.item():.3f} " - f"Time: {ellipse_time:.3f}s/image " - f"Number of GS: {len(self.splats['means'])}" - ) - # save stats as json - stats = { - "psnr": psnr.item(), - "ssim": ssim.item(), - "lpips": lpips.item(), - "ellipse_time": ellipse_time, - "num_GS": len(self.splats["means"]), - } - with open(f"{self.stats_dir}/val_step{step:04d}.json", "w") as f: - json.dump(stats, f) - # save stats to tensorboard - for k, v in stats.items(): - self.writer.add_scalar(f"val/{k}", v, step) - self.writer.flush() - - @torch.no_grad() - def render_traj(self, step: int): - """Entry for trajectory rendering.""" - print("Running trajectory rendering...") - cfg = self.cfg - device = self.device - - camtoworlds = self.parser.camtoworlds[5:-5] - camtoworlds = generate_interpolated_path(camtoworlds, 1) # [N, 3, 4] - camtoworlds = np.concatenate( - [ - camtoworlds, - np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0), - ], - axis=1, - ) # 
[N, 4, 4] - - camtoworlds = torch.from_numpy(camtoworlds).float().to(device) - K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) - width, height = list(self.parser.imsize_dict.values())[0] - - canvas_all = [] - for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"): - renders, _, _ = self.rasterize_splats( - camtoworlds=camtoworlds[i : i + 1], - Ks=K[None], - width=width, - height=height, - sh_degree=cfg.sh_degree, - near_plane=cfg.near_plane, - far_plane=cfg.far_plane, - render_mode="RGB+ED", - ) # [1, H, W, 4] - colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3] - depths = renders[0, ..., 3:4] # [H, W, 1] - depths = (depths - depths.min()) / (depths.max() - depths.min()) - - # write images - canvas = torch.cat( - [colors, depths.repeat(1, 1, 3)], dim=0 if width > height else 1 - ) - canvas = (canvas.cpu().numpy() * 255).astype(np.uint8) - canvas_all.append(canvas) - - # save to video - video_dir = f"{cfg.result_dir}/videos" - os.makedirs(video_dir, exist_ok=True) - writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) - for canvas in canvas_all: - writer.append_data(canvas) - writer.close() - print(f"Video saved to {video_dir}/traj_{step}.mp4") - - @torch.no_grad() - def _viewer_render_fn( - self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] - ): - """Callable function for the viewer.""" - W, H = img_wh - c2w = camera_state.c2w - K = camera_state.get_K(img_wh) - c2w = torch.from_numpy(c2w).float().to(self.device) - K = torch.from_numpy(K).float().to(self.device) - - render_colors, _, _ = self.rasterize_splats( - camtoworlds=c2w[None], - Ks=K[None], - width=W, - height=H, - sh_degree=self.cfg.sh_degree, # active all SH degrees - radius_clip=3.0, # skip GSs that have small image radius (in pixels) - ) # [1, H, W, 3] - return render_colors[0].cpu().numpy() - - -def main(local_rank: int, world_rank, world_size: int, cfg: Config): - if world_size > 1 and not cfg.disable_viewer: - cfg.disable_viewer = True - if world_rank == 0: - print("Viewer is disabled in distributed training.") - - runner = Runner(local_rank, world_rank, world_size, cfg) - - if cfg.ckpt is not None: - # run eval only - ckpt = torch.load(cfg.ckpt, map_location=runner.device) - for k in runner.splats.keys(): - runner.splats[k].data = ckpt["splats"][k] - runner.eval(step=ckpt["step"]) - runner.render_traj(step=ckpt["step"]) - else: - runner.train() - - if not cfg.disable_viewer: - print("Viewer running... Ctrl+C to exit.") - time.sleep(1000000) - - -if __name__ == "__main__": - cfg = tyro.cli(Config) - cfg.adjust_steps(cfg.steps_scaler) - cli(main, cfg, verbose=True) diff --git a/gsplat/cuda_legacy/_wrapper.py b/gsplat/cuda_legacy/_wrapper.py index c33da69fa..d5d0f2c64 100644 --- a/gsplat/cuda_legacy/_wrapper.py +++ b/gsplat/cuda_legacy/_wrapper.py @@ -91,7 +91,7 @@ def get_tile_bin_edges( def compute_cov2d_bounds( - cov2d: Float[Tensor, "batch 3"] + cov2d: Float[Tensor, "batch 3"], ) -> Tuple[Float[Tensor, "batch_conics 3"], Float[Tensor, "batch_radii 1"]]: """Computes bounds of 2D covariance matrix @@ -113,7 +113,7 @@ def compute_cov2d_bounds( def compute_cumulative_intersects( - num_tiles_hit: Float[Tensor, "batch 1"] + num_tiles_hit: Float[Tensor, "batch 1"], ) -> Tuple[int, Float[Tensor, "batch 1"]]: """Computes cumulative intersections of gaussians. This is useful for creating unique gaussian IDs and for sorting. 
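Aside: the trainer changes above fold the deleted `examples/simple_trainer_mcmc.py` into `examples/simple_trainer.py` behind `default` and `mcmc` subcommands built with tyro. Below is a minimal, self-contained sketch of that subcommand pattern; the tiny `Config` and strategy dataclasses are illustrative stand-ins for the real trainer dataclasses, and only the tyro calls mirror the diff.

```python
# Sketch of the tyro subcommand pattern adopted in simple_trainer.py.
# The dataclasses below are toy stand-ins, not the real trainer config.
from dataclasses import dataclass, field
from typing import Union

import tyro


@dataclass
class DefaultStrategy:
    grow_grad2d: float = 0.0002


@dataclass
class MCMCStrategy:
    cap_max: int = 1_000_000


@dataclass
class Config:
    strategy: Union[DefaultStrategy, MCMCStrategy] = field(
        default_factory=DefaultStrategy
    )
    init_opa: float = 0.1


if __name__ == "__main__":
    # Each subcommand name maps to a fully constructed Config instance.
    subcommand_type = tyro.extras.subcommand_type_from_defaults(
        defaults={
            "default": Config(strategy=DefaultStrategy()),
            "mcmc": Config(init_opa=0.5, strategy=MCMCStrategy()),
        },
        descriptions={
            "default": "Densification heuristics from the original paper.",
            "mcmc": "Densification from '3DGS as Markov Chain Monte Carlo'.",
        },
    )
    # AvoidSubcommands pins the strategy type chosen by the subcommand, so
    # `prog mcmc --strategy.cap-max 500000` tunes fields but cannot swap the
    # strategy class from the CLI.
    subcommand_type = tyro.conf.configure(tyro.conf.AvoidSubcommands)(subcommand_type)
    cfg = tyro.cli(subcommand_type)
    print(cfg)
```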
diff --git a/gsplat/rendering.py b/gsplat/rendering.py index a18363beb..9fedf4b8f 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -1,5 +1,5 @@ import math -from typing import Dict, List, Optional, Tuple +from typing import Dict, Optional, Tuple import torch import torch.distributed diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py index 8ce957d0f..c35b8bbd3 100644 --- a/gsplat/strategy/default.py +++ b/gsplat/strategy/default.py @@ -30,8 +30,6 @@ class DefaultStrategy(Strategy): with `absgrad=True` as well so that the absolute gradients are computed. Args: - scene_scale (float): The scale of the scene for calibrating the scale-related - logic. Default is 1.0. prune_opa (float): GSs with opacity below this value will be pruned. Default is 0.005. grow_grad2d (float): GSs with image plane gradient above this value will be split/duplicated. Default is 0.0002. @@ -71,7 +69,6 @@ class DefaultStrategy(Strategy): """ - scene_scale: float = 1.0 prune_opa: float = 0.005 grow_grad2d: float = 0.0002 grow_scale3d: float = 0.01 @@ -87,7 +84,7 @@ class DefaultStrategy(Strategy): revised_opacity: bool = False verbose: bool = False - def initialize_state(self) -> Dict[str, Any]: + def initialize_state(self, scene_scale: float) -> Dict[str, Any]: """Initialize and return the running state for this strategy. The returned state should be passed to the `step_pre_backward()` and @@ -98,7 +95,7 @@ def initialize_state(self) -> Dict[str, Any]: # - grad2d: running accum of the norm of the image plane gradients for each GS. # - count: running accum of how many time each GS is visible. # - radii: the radii of the GSs (normalized by the image resolution). - state = {"grad2d": None, "count": None} + state = {"grad2d": None, "count": None, "scene_scale": scene_scale} if self.refine_scale2d_stop_iter > 0: state["radii"] = None return state @@ -255,7 +252,7 @@ def _grow_gs( is_grad_high = grads > self.grow_grad2d is_small = ( torch.exp(params["scales"]).max(dim=-1).values - <= self.grow_scale3d * self.scene_scale + <= self.grow_scale3d * state["scene_scale"] ) is_dupli = is_grad_high & is_small n_dupli = is_dupli.sum().item() @@ -301,7 +298,7 @@ def _prune_gs( if step > self.reset_every: is_too_big = ( torch.exp(params["scales"]).max(dim=-1).values - > self.prune_scale3d * self.scene_scale + > self.prune_scale3d * state["scene_scale"] ) # The official code also implements sreen-size pruning but # it's actually not being used due to a bug: diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 4ea774465..61e37d4be 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -75,7 +75,8 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) # update the extra running state for k, v in state.items(): - state[k] = torch.cat((v, v[sel])) + if isinstance(v, torch.Tensor): + state[k] = torch.cat((v, v[sel])) @torch.no_grad() @@ -132,9 +133,10 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) # update the extra running state for k, v in state.items(): - repeats = [2] + [1] * (v.dim() - 1) - v_new = v[sel].repeat(repeats) - state[k] = torch.cat((v[rest], v_new)) + if isinstance(v, torch.Tensor): + repeats = [2] + [1] * (v.dim() - 1) + v_new = v[sel].repeat(repeats) + state[k] = torch.cat((v[rest], v_new)) @torch.no_grad() @@ -163,7 +165,8 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: _update_param_with_optimizer(param_fn, 
optimizer_fn, params, optimizers) # update the extra running state for k, v in state.items(): - state[k] = v[sel] + if isinstance(v, torch.Tensor): + state[k] = v[sel] @torch.no_grad() @@ -248,7 +251,8 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) # update the extra running state for k, v in state.items(): - v[sampled_idxs] = 0 + if isinstance(v, torch.Tensor): + v[sampled_idxs] = 0 @torch.no_grad() @@ -290,7 +294,8 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: # update the extra running state for k, v in state.items(): - v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device) - state[k] = torch.cat((v, v_new)) + if isinstance(v, torch.Tensor): + v_new = torch.zeros((len(sampled_idxs), *v.shape[1:]), device=v.device) + state[k] = torch.cat((v, v_new)) @torch.no_grad() diff --git a/setup.py b/setup.py index 7b36e4dc8..3ea2dfe33 100644 --- a/setup.py +++ b/setup.py @@ -98,7 +98,7 @@ def get_extensions(): current_dir = pathlib.Path(__file__).parent.resolve() glm_path = os.path.join(current_dir, "gsplat", "cuda", "csrc", "third_party", "glm") extension_v1 = CUDAExtension( - f"gsplat.csrc_legacy", + "gsplat.csrc_legacy", sources_v1, include_dirs=[extensions_dir_v2, glm_path], # glm lives in v2. define_macros=define_macros, @@ -107,7 +107,7 @@ extra_link_args=extra_link_args, ) extension_v2 = CUDAExtension( - f"gsplat.csrc", + "gsplat.csrc", sources_v2, include_dirs=[extensions_dir_v2, glm_path], # glm lives in v2. define_macros=define_macros, diff --git a/tests/test_basic.py b/tests/test_basic.py index f8864915e..d150360c4 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -445,7 +445,6 @@ def test_rasterize_to_pixels(test_data, channels: int): fully_fused_projection, isect_offset_encode, isect_tiles, - persp_proj, quat_scale_to_covar_preci, rasterize_to_pixels, ) diff --git a/tests/test_rasterization.py b/tests/test_rasterization.py index 57ea44bee..1aad8738b 100644 --- a/tests/test_rasterization.py +++ b/tests/test_rasterization.py @@ -6,7 +6,6 @@ ``` """ -import math from typing import Optional import pytest diff --git a/tests/test_strategy.py b/tests/test_strategy.py index 3cd634df7..5f432a9ba 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -6,12 +6,8 @@ ``` """ -import math -from typing import Optional - import pytest import torch -import torch.nn.functional as F device = torch.device("cuda:0")
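Aside: the `isinstance(v, torch.Tensor)` guards added to `gsplat/strategy/ops.py` are needed because `DefaultStrategy.initialize_state()` now stores a plain-float `scene_scale` alongside the per-Gaussian buffers (`grad2d`, `count`, and optionally `radii`, once populated), so the refinement ops must resize only the tensor entries. A minimal, runnable sketch of the pattern, with a toy state dict and hypothetical sizes:

```python
# Duplicating Gaussians grows every per-GS tensor buffer in the strategy
# state, while scalar entries such as scene_scale pass through untouched.
import torch

state = {
    "grad2d": torch.zeros(5),                    # per-GS gradient-norm accumulator
    "count": torch.zeros(5, dtype=torch.int32),  # per-GS visibility counter
    "scene_scale": 1.25,                         # plain float, not per-GS
}
sel = torch.tensor([1, 3])  # indices of the Gaussians being duplicated

for k, v in state.items():
    if isinstance(v, torch.Tensor):
        state[k] = torch.cat((v, v[sel]))

assert state["grad2d"].shape[0] == 7 and state["count"].shape[0] == 7
assert state["scene_scale"] == 1.25  # untouched
```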