Finalize MCMC strategy and some tiny fixes #3548

Open · wants to merge 11 commits into main
63 changes: 62 additions & 1 deletion nerfstudio/configs/method_configs.py
@@ -696,7 +696,68 @@
),
},
"bilateral_grid": {
"optimizer": AdamOptimizerConfig(lr=5e-3, eps=1e-15),
"optimizer": AdamOptimizerConfig(lr=2e-3, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(
lr_final=1e-4, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
),
},
},
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
vis="viewer",
)

method_configs["splatfacto-mcmc"] = TrainerConfig(
method_name="splatfacto",
steps_per_eval_image=100,
steps_per_eval_batch=0,
steps_per_save=2000,
steps_per_eval_all_images=1000,
max_num_iterations=30000,
mixed_precision=False,
pipeline=VanillaPipelineConfig(
datamanager=FullImageDatamanagerConfig(
dataparser=NerfstudioDataParserConfig(load_3D_points=True),
cache_images_type="uint8",
),
model=SplatfactoModelConfig(
strategy="mcmc",
cull_alpha_thresh=0.005,
stop_split_at=25000,
),
),
optimizers={
"means": {
"optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(
lr_final=1.6e-6,
max_steps=30000,
),
},
"features_dc": {
"optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15),
"scheduler": None,
},
"features_rest": {
"optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15),
"scheduler": None,
},
"opacities": {
"optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15),
"scheduler": None,
},
"scales": {
"optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15),
"scheduler": None,
},
"quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None},
"camera_opt": {
"optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(
lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
),
},
"bilateral_grid": {
"optimizer": AdamOptimizerConfig(lr=2e-3, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(
lr_final=1e-4, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
),
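For context, every optimizer entry above pairs Adam with an ExponentialDecaySchedulerConfig. The sketch below approximates the decay curve those settings imply, assuming a linear warmup from lr_pre_warmup followed by log-linear interpolation from the optimizer lr down to lr_final; it illustrates the parameters rather than reproducing nerfstudio's exact scheduler code. Once this entry is registered, the new variant should be launchable like any other method, e.g. ns-train splatfacto-mcmc --data <path> (command illustrative; exact flags depend on the installed version).

import math

def approx_exp_decay_lr(step, lr_init, lr_final, max_steps, warmup_steps=0, lr_pre_warmup=0.0):
    """Approximate lr at a given step (sketch only; the real scheduler may differ)."""
    if warmup_steps > 0 and step < warmup_steps:
        # Assumed linear ramp from lr_pre_warmup up to lr_init during warmup.
        return lr_pre_warmup + (lr_init - lr_pre_warmup) * step / warmup_steps
    t = min(max(step - warmup_steps, 0) / max(max_steps - warmup_steps, 1), 1.0)
    # Log-linear interpolation between lr_init and lr_final.
    return math.exp((1.0 - t) * math.log(lr_init) + t * math.log(lr_final))

# "camera_opt" above: lr=1e-4 decaying to 5e-7 over 30000 steps with a 1000-step warmup.
for s in (0, 500, 1000, 15000, 30000):
    print(s, approx_exp_decay_lr(s, 1e-4, 5e-7, 30000, warmup_steps=1000))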
106 changes: 78 additions & 28 deletions nerfstudio/models/splatfacto.py
@@ -23,7 +23,7 @@
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

import torch
from gsplat.strategy import DefaultStrategy
from gsplat.strategy import DefaultStrategy, MCMCStrategy

try:
from gsplat.rendering import rasterization
@@ -156,6 +156,16 @@ class SplatfactoModelConfig(ModelConfig):
"""Shape of the bilateral grid (X, Y, W)"""
color_corrected_metrics: bool = False
"""If True, apply color correction to the rendered images before computing the metrics."""
strategy: Literal["default", "mcmc"] = "default"
"""The default strategy will be used if strategy is not specified. Other strategies, e.g. mcmc, can be used."""
max_gs_num: int = 1_000_000
"""Maximum number of GSs. Default to 1_000_000."""
noise_lr: float = 5e5
"""MCMC samping noise learning rate. Default to 5e5."""
mcmc_opacity_reg: float = 0.01
"""Regularization term for opacity in MCMC strategy. Only enabled when using MCMC strategy"""
mcmc_scale_reg: float = 0.01
"""Regularization term for scale in MCMC strategy. Only enabled when using MCMC strategy"""


class SplatfactoModel(Model):
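These options only take effect once strategy is switched away from "default". A minimal sketch of overriding them on the model config follows; the values are illustrative, not recommendations.

from nerfstudio.models.splatfacto import SplatfactoModelConfig

# Hypothetical override: use the MCMC strategy and cap the number of Gaussians.
model_config = SplatfactoModelConfig(
    strategy="mcmc",
    max_gs_num=500_000,   # cap handed to the MCMC strategy as its maximum Gaussian count
    noise_lr=5e5,         # sampling-noise learning rate
    mcmc_opacity_reg=0.01,
    mcmc_scale_reg=0.01,
)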
@@ -249,24 +259,40 @@ def populate_modules(self):
)

# Strategy for GS densification
self.strategy = DefaultStrategy(
prune_opa=self.config.cull_alpha_thresh,
grow_grad2d=self.config.densify_grad_thresh,
grow_scale3d=self.config.densify_size_thresh,
grow_scale2d=self.config.split_screen_size,
prune_scale3d=self.config.cull_scale_thresh,
prune_scale2d=self.config.cull_screen_size,
refine_scale2d_stop_iter=self.config.stop_screen_size_at,
refine_start_iter=self.config.warmup_length,
refine_stop_iter=self.config.stop_split_at,
reset_every=self.config.reset_alpha_every * self.config.refine_every,
refine_every=self.config.refine_every,
pause_refine_after_reset=self.num_train_data + self.config.refine_every,
absgrad=self.config.use_absgrad,
revised_opacity=False,
verbose=True,
)
self.strategy_state = self.strategy.initialize_state(scene_scale=1.0)
if self.config.strategy == "default":
# Strategy for GS densification
self.strategy = DefaultStrategy(
prune_opa=self.config.cull_alpha_thresh,
grow_grad2d=self.config.densify_grad_thresh,
grow_scale3d=self.config.densify_size_thresh,
grow_scale2d=self.config.split_screen_size,
prune_scale3d=self.config.cull_scale_thresh,
prune_scale2d=self.config.cull_screen_size,
refine_scale2d_stop_iter=self.config.stop_screen_size_at,
refine_start_iter=self.config.warmup_length,
refine_stop_iter=self.config.stop_split_at,
reset_every=self.config.reset_alpha_every * self.config.refine_every,
refine_every=self.config.refine_every,
pause_refine_after_reset=self.num_train_data + self.config.refine_every,
absgrad=self.config.use_absgrad,
revised_opacity=False,
verbose=True,
)
self.strategy_state = self.strategy.initialize_state(scene_scale=1.0)
elif self.config.strategy == "mcmc":
self.strategy = MCMCStrategy(
cap_max=self.config.max_gs_num,
noise_lr=self.config.noise_lr,
refine_start_iter=self.config.warmup_length,
refine_stop_iter=self.config.stop_split_at,
refine_every=self.config.refine_every,
min_opacity=self.config.cull_alpha_thresh,
verbose=False,
)
self.strategy_state = self.strategy.initialize_state()
else:
raise ValueError(f"""Splatfacto does not support strategy {self.config.strategy}
Currently, the supported strategies include default and mcmc.""")

@property
def colors(self):
@@ -338,14 +364,26 @@ def set_background(self, background_color: torch.Tensor):

def step_post_backward(self, step):
assert step == self.step
self.strategy.step_post_backward(
params=self.gauss_params,
optimizers=self.optimizers,
state=self.strategy_state,
step=self.step,
info=self.info,
packed=False,
)
if isinstance(self.strategy, DefaultStrategy):
self.strategy.step_post_backward(
params=self.gauss_params,
optimizers=self.optimizers,
state=self.strategy_state,
step=self.step,
info=self.info,
packed=False,
)
elif isinstance(self.strategy, MCMCStrategy):
self.strategy.step_post_backward(
params=self.gauss_params,
optimizers=self.optimizers,
state=self.strategy_state,
step=step,
info=self.info,
lr=self.schedulers["means"].get_last_lr()[0], # the learning rate for the "means" attribute of the GS
)
else:
raise ValueError(f"Unknown strategy {self.strategy}")

def get_training_callbacks(
self, training_callback_attributes: TrainingCallbackAttributes
@@ -369,6 +407,7 @@ def get_training_callbacks(
def step_cb(self, optimizers: Optimizers, step):
self.step = step
self.optimizers = optimizers.optimizers
self.schedulers = optimizers.schedulers

def get_gaussian_param_groups(self) -> Dict[str, List[Parameter]]:
# Here we explicitly use the means, scales as parameters so that the user can override this function and
@@ -529,7 +568,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
render_mode=render_mode,
sh_degree=sh_degree_to_use,
sparse_grad=False,
absgrad=self.strategy.absgrad,
absgrad=self.strategy.absgrad if isinstance(self.strategy, DefaultStrategy) else False,
rasterize_mode=self.config.rasterize_mode,
# set some threshold to disregard small gaussians for faster rendering.
# radius_clip=3.0,
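The guard on absgrad above is needed presumably because only DefaultStrategy exposes that attribute; MCMC densification does not use absolute 2D gradients, so the flag falls back to False. An equivalent (but not what this PR uses) formulation would be absgrad=getattr(self.strategy, "absgrad", False).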
@@ -651,6 +690,17 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
"scale_reg": scale_reg,
}

# Losses for mcmc
if self.config.strategy == "mcmc":
if self.config.mcmc_opacity_reg > 0.0:
mcmc_opacity_reg = (
self.config.mcmc_opacity_reg * torch.abs(torch.sigmoid(self.gauss_params["opacities"])).mean()
)
loss_dict["mcmc_opacity_reg"] = mcmc_opacity_reg
if self.config.mcmc_scale_reg > 0.0:
mcmc_scale_reg = self.config.mcmc_scale_reg * torch.abs(torch.exp(self.gauss_params["scales"])).mean()
loss_dict["mcmc_scale_reg"] = mcmc_scale_reg

if self.training:
# Add loss from camera optimizer
self.camera_optimizer.get_loss_dict(loss_dict)
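The two new loss terms are the opacity and scale regularizers used with the MCMC strategy: mcmc_opacity_reg * mean(|sigmoid(opacities)|) over the raw (pre-sigmoid) opacities and mcmc_scale_reg * mean(|exp(scales)|) over the raw log-scales. A standalone sketch of the same computation, outside the model class (tensor shapes are illustrative):

import torch

def mcmc_regularizers(opacities_raw, scales_raw, opacity_reg=0.01, scale_reg=0.01):
    """Same penalties as added to get_loss_dict above.

    opacities_raw are pre-sigmoid logits and scales_raw are log-scales, matching
    how splatfacto stores them in gauss_params (hence the sigmoid/exp below)."""
    reg = {}
    if opacity_reg > 0.0:
        # Encourage low opacities so the MCMC strategy can relocate low-impact Gaussians.
        reg["mcmc_opacity_reg"] = opacity_reg * torch.sigmoid(opacities_raw).abs().mean()
    if scale_reg > 0.0:
        # Penalize the physical (exponentiated) scales to discourage oversized Gaussians.
        reg["mcmc_scale_reg"] = scale_reg * torch.exp(scales_raw).abs().mean()
    return reg

# e.g. 1000 Gaussians with random raw parameters
print(mcmc_regularizers(torch.randn(1000, 1), torch.randn(1000, 3)))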
1 change: 1 addition & 0 deletions tests/test_train.py
@@ -28,6 +28,7 @@
"neus-facto",
"splatfacto",
"splatfacto-big",
"splatfacto-mcmc",
]
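With the name added to this list, splatfacto-mcmc should be covered by the same training smoke test as the other registered methods, assuming the test parametrizes over this list (e.g. by running pytest tests/test_train.py).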

