Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MCMC Strategy in Splatfacto #3436

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 80 additions & 29 deletions nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@

import numpy as np
import torch
from gsplat.strategy import DefaultStrategy
from gsplat.strategy import DefaultStrategy, MCMCStrategy

try:
from gsplat.rendering import rasterization
except ImportError:
print("Please install gsplat>=1.0.0")
from pytorch_msssim import SSIM
from torch.nn import Parameter
from typing_extensions import assert_never

from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras
Expand Down Expand Up @@ -218,6 +219,19 @@ class SplatfactoModelConfig(ModelConfig):
"""Shape of the bilateral grid (X, Y, W)"""
color_corrected_metrics: bool = False
"""If True, apply color correction to the rendered images before computing the metrics."""
strategy: Literal["default", "mcmc"] = "default"
"""The default strategy will be used if strategy is not specified. Other strategies, e.g. mcmc, can be used."""

cap_max: int = 1_000_000
"""Maximum number of GSs. Default to 1_000_000."""
noise_lr: float = 5e5
"""MCMC samping noise learning rate. Default to 5e5."""
min_opacity: float = 0.005
"""GSs with opacity below this value will be pruned. Default to 0.005."""
verbose: bool = False
"""Whether to print verbose information. Default to False."""
max_steps: int = 30_000
"""Number of training steps"""


class SplatfactoModel(Model):
Expand Down Expand Up @@ -311,25 +325,40 @@ def populate_modules(self):
grid_W=self.config.grid_shape[2],
)

# Strategy for GS densification
self.strategy = DefaultStrategy(
prune_opa=self.config.cull_alpha_thresh,
grow_grad2d=self.config.densify_grad_thresh,
grow_scale3d=self.config.densify_size_thresh,
grow_scale2d=self.config.split_screen_size,
prune_scale3d=self.config.cull_scale_thresh,
prune_scale2d=self.config.cull_screen_size,
refine_scale2d_stop_iter=self.config.stop_screen_size_at,
refine_start_iter=self.config.warmup_length,
refine_stop_iter=self.config.stop_split_at,
reset_every=self.config.reset_alpha_every * self.config.refine_every,
refine_every=self.config.refine_every,
pause_refine_after_reset=self.num_train_data + self.config.refine_every,
absgrad=self.config.use_absgrad,
revised_opacity=False,
verbose=True,
)
self.strategy_state = self.strategy.initialize_state(scene_scale=1.0)
if self.config.strategy == "default":
# Strategy for GS densification
self.strategy = DefaultStrategy(
prune_opa=self.config.cull_alpha_thresh,
grow_grad2d=self.config.densify_grad_thresh,
grow_scale3d=self.config.densify_size_thresh,
grow_scale2d=self.config.split_screen_size,
prune_scale3d=self.config.cull_scale_thresh,
prune_scale2d=self.config.cull_screen_size,
refine_scale2d_stop_iter=self.config.stop_screen_size_at,
refine_start_iter=self.config.warmup_length,
refine_stop_iter=self.config.stop_split_at,
reset_every=self.config.reset_alpha_every * self.config.refine_every,
refine_every=self.config.refine_every,
pause_refine_after_reset=self.num_train_data + self.config.refine_every,
absgrad=self.config.use_absgrad,
revised_opacity=False,
verbose=True,
)
self.strategy_state = self.strategy.initialize_state(scene_scale=1.0)
elif self.config.strategy == "mcmc":
self.strategy = MCMCStrategy(
cap_max=self.config.cap_max,
noise_lr=self.config.noise_lr,
refine_start_iter=self.config.warmup_length,
refine_stop_iter=self.config.stop_split_at,
refine_every=self.config.refine_every,
min_opacity=self.config.min_opacity,
verbose=self.config.verbose,
)
self.strategy_state = self.strategy.initialize_state()
else:
raise ValueError(f"""Splatfacto does not support strategy {self.config.strategy}
Currently, the supported strategies include default and mcmc.""")

@property
def colors(self):
Expand Down Expand Up @@ -421,14 +450,36 @@ def set_background(self, background_color: torch.Tensor):

def step_post_backward(self, step):
assert step == self.step
self.strategy.step_post_backward(
params=self.gauss_params,
optimizers=self.optimizers,
state=self.strategy_state,
step=self.step,
info=self.info,
packed=False,
)

schedulers = [
torch.optim.lr_scheduler.ExponentialLR(
self.optimizers["means"], gamma=0.01 ** (1.0 / self.config.max_steps)
)
]

if isinstance(self.strategy, DefaultStrategy):
self.strategy.step_post_backward(
params=self.gauss_params,
optimizers=self.optimizers,
state=self.strategy_state,
step=self.step,
info=self.info,
packed=False,
)
elif isinstance(self.strategy, MCMCStrategy):
self.strategy.step_post_backward(
params=self.gauss_params,
optimizers=self.optimizers,
state=self.strategy_state,
step=step,
info=self.info,
lr=schedulers[0].get_last_lr()[0],
)
else:
assert_never(self.cfg.strategy)

for scheduler in schedulers:
scheduler.step()

def get_training_callbacks(
self, training_callback_attributes: TrainingCallbackAttributes
Expand Down Expand Up @@ -612,7 +663,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
render_mode=render_mode,
sh_degree=sh_degree_to_use,
sparse_grad=False,
absgrad=self.strategy.absgrad,
absgrad=(self.strategy.absgrad if isinstance(self.strategy, DefaultStrategy) else False),
rasterize_mode=self.config.rasterize_mode,
# set some threshold to disregrad small gaussians for faster rendering.
# radius_clip=3.0,
Expand Down
Loading