diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index 28b8f0a1de..bd216e1ab6 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -25,7 +25,7 @@ import numpy as np import torch -from gsplat.strategy import DefaultStrategy +from gsplat.strategy import DefaultStrategy, MCMCStrategy try: from gsplat.rendering import rasterization @@ -33,6 +33,7 @@ print("Please install gsplat>=1.0.0") from pytorch_msssim import SSIM from torch.nn import Parameter +from typing_extensions import assert_never from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig from nerfstudio.cameras.cameras import Cameras @@ -218,6 +219,19 @@ class SplatfactoModelConfig(ModelConfig): """Shape of the bilateral grid (X, Y, W)""" color_corrected_metrics: bool = False """If True, apply color correction to the rendered images before computing the metrics.""" + strategy: Literal["default", "mcmc"] = "default" + """The default strategy will be used if strategy is not specified. Other strategies, e.g. mcmc, can be used.""" + + cap_max: int = 1_000_000 + """Maximum number of GSs. Default to 1_000_000.""" + noise_lr: float = 5e5 + """MCMC samping noise learning rate. Default to 5e5.""" + min_opacity: float = 0.005 + """GSs with opacity below this value will be pruned. Default to 0.005.""" + verbose: bool = False + """Whether to print verbose information. Default to False.""" + max_steps: int = 30_000 + """Number of training steps""" class SplatfactoModel(Model): @@ -311,25 +325,40 @@ def populate_modules(self): grid_W=self.config.grid_shape[2], ) - # Strategy for GS densification - self.strategy = DefaultStrategy( - prune_opa=self.config.cull_alpha_thresh, - grow_grad2d=self.config.densify_grad_thresh, - grow_scale3d=self.config.densify_size_thresh, - grow_scale2d=self.config.split_screen_size, - prune_scale3d=self.config.cull_scale_thresh, - prune_scale2d=self.config.cull_screen_size, - refine_scale2d_stop_iter=self.config.stop_screen_size_at, - refine_start_iter=self.config.warmup_length, - refine_stop_iter=self.config.stop_split_at, - reset_every=self.config.reset_alpha_every * self.config.refine_every, - refine_every=self.config.refine_every, - pause_refine_after_reset=self.num_train_data + self.config.refine_every, - absgrad=self.config.use_absgrad, - revised_opacity=False, - verbose=True, - ) - self.strategy_state = self.strategy.initialize_state(scene_scale=1.0) + if self.config.strategy == "default": + # Strategy for GS densification + self.strategy = DefaultStrategy( + prune_opa=self.config.cull_alpha_thresh, + grow_grad2d=self.config.densify_grad_thresh, + grow_scale3d=self.config.densify_size_thresh, + grow_scale2d=self.config.split_screen_size, + prune_scale3d=self.config.cull_scale_thresh, + prune_scale2d=self.config.cull_screen_size, + refine_scale2d_stop_iter=self.config.stop_screen_size_at, + refine_start_iter=self.config.warmup_length, + refine_stop_iter=self.config.stop_split_at, + reset_every=self.config.reset_alpha_every * self.config.refine_every, + refine_every=self.config.refine_every, + pause_refine_after_reset=self.num_train_data + self.config.refine_every, + absgrad=self.config.use_absgrad, + revised_opacity=False, + verbose=True, + ) + self.strategy_state = self.strategy.initialize_state(scene_scale=1.0) + elif self.config.strategy == "mcmc": + self.strategy = MCMCStrategy( + cap_max=self.config.cap_max, + noise_lr=self.config.noise_lr, + refine_start_iter=self.config.warmup_length, + refine_stop_iter=self.config.stop_split_at, + refine_every=self.config.refine_every, + min_opacity=self.config.min_opacity, + verbose=self.config.verbose, + ) + self.strategy_state = self.strategy.initialize_state() + else: + raise ValueError(f"""Splatfacto does not support strategy {self.config.strategy} + Currently, the supported strategies include default and mcmc.""") @property def colors(self): @@ -421,14 +450,36 @@ def set_background(self, background_color: torch.Tensor): def step_post_backward(self, step): assert step == self.step - self.strategy.step_post_backward( - params=self.gauss_params, - optimizers=self.optimizers, - state=self.strategy_state, - step=self.step, - info=self.info, - packed=False, - ) + + schedulers = [ + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["means"], gamma=0.01 ** (1.0 / self.config.max_steps) + ) + ] + + if isinstance(self.strategy, DefaultStrategy): + self.strategy.step_post_backward( + params=self.gauss_params, + optimizers=self.optimizers, + state=self.strategy_state, + step=self.step, + info=self.info, + packed=False, + ) + elif isinstance(self.strategy, MCMCStrategy): + self.strategy.step_post_backward( + params=self.gauss_params, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=self.info, + lr=schedulers[0].get_last_lr()[0], + ) + else: + assert_never(self.cfg.strategy) + + for scheduler in schedulers: + scheduler.step() def get_training_callbacks( self, training_callback_attributes: TrainingCallbackAttributes @@ -612,7 +663,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: render_mode=render_mode, sh_degree=sh_degree_to_use, sparse_grad=False, - absgrad=self.strategy.absgrad, + absgrad=(self.strategy.absgrad if isinstance(self.strategy, DefaultStrategy) else False), rasterize_mode=self.config.rasterize_mode, # set some threshold to disregrad small gaussians for faster rendering. # radius_clip=3.0,