diff --git a/gsplat/relocation.py b/gsplat/relocation.py index da1a80d53..e532446cf 100644 --- a/gsplat/relocation.py +++ b/gsplat/relocation.py @@ -1,19 +1,17 @@ import math + import torch from torch import Tensor -from .cuda._wrapper import _make_lazy_cuda_func -N_MAX = 51 -BINOMS = torch.zeros((N_MAX, N_MAX)).float().cuda() -for n in range(N_MAX): - for k in range(n + 1): - BINOMS[n, k] = math.comb(n, k) +from .cuda._wrapper import _make_lazy_cuda_func def compute_relocation( opacities: Tensor, # [N] scales: Tensor, # [N, 3] ratios: Tensor, # [N] + binoms: Tensor, # [n_max, n_max] + n_max: int, ) -> tuple[Tensor, Tensor]: """Compute new Gaussians from a set of old Gaussians. @@ -38,10 +36,10 @@ def compute_relocation( assert ratios.shape == (N,), ratios.shape opacities = opacities.contiguous() scales = scales.contiguous() - ratios.clamp_(min=1, max=N_MAX) + ratios.clamp_(min=1, max=n_max) ratios = ratios.int().contiguous() new_opacities, new_scales = _make_lazy_cuda_func("compute_relocation")( - opacities, scales, ratios, BINOMS, N_MAX + opacities, scales, ratios, binoms, n_max ) return new_opacities, new_scales diff --git a/gsplat/strategy/mcmc.py b/gsplat/strategy/mcmc.py index 854806bb6..9b5abc3bb 100644 --- a/gsplat/strategy/mcmc.py +++ b/gsplat/strategy/mcmc.py @@ -1,3 +1,4 @@ +import math from dataclasses import dataclass from typing import Any, Callable, DefaultDict, Dict, List, Tuple, Union @@ -5,9 +6,6 @@ import torch.nn.functional as F from torch import Tensor -from gsplat import quat_scale_to_covar_preci -from gsplat.relocation import compute_relocation - from .base import Strategy from .ops import inject_noise_to_position, relocate, sample_add @@ -40,11 +38,12 @@ class MCMCStrategy(Strategy): >>> optimizers: Dict[str, torch.optim.Optimizer] = ... >>> strategy = MCMCStrategy() >>> strategy.check_sanity(params, optimizers) + >>> strategy_state = strategy.initialize_state() >>> for step in range(1000): ... render_image, render_alpha, info = rasterization(...) ... loss = ... ... loss.backward() - ... strategy.step_post_backward(params, optimizers, step, info, lr=1e-3) + ... strategy.step_post_backward(params, optimizers, strategy_state, step, info, lr=1e-3) """ @@ -56,9 +55,14 @@ class MCMCStrategy(Strategy): min_opacity: float = 0.005 verbose: bool = False - # def initialize_state(self) -> Dict[str, Any]: - # """Initialize and return the running state for this strategy.""" - # return {} + def initialize_state(self) -> Dict[str, Any]: + """Initialize and return the running state for this strategy.""" + n_max = 51 + binoms = torch.zeros((n_max, n_max)) + for n in range(n_max): + for k in range(n + 1): + binoms[n, k] = math.comb(n, k) + return {"binoms": binoms, "n_max": n_max} def check_sanity( self, @@ -101,7 +105,7 @@ def step_post_backward( self, params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], optimizers: Dict[str, torch.optim.Optimizer], - # state: Dict[str, Any], + state: Dict[str, Any], step: int, info: Dict[str, Any], lr: float, @@ -111,18 +115,24 @@ def step_post_backward( Args: lr (float): Learning rate for "means" attribute of the GS. """ + # move to the correct device + state["binoms"] = state["binoms"].to(params["means"].device) + + binoms = state["binoms"] + n_max = state["n_max"] + if ( step < self.refine_stop_iter and step > self.refine_start_iter and step % self.refine_every == 0 ): # teleport GSs - n_relocated_gs = self._relocate_gs(params, optimizers) + n_relocated_gs = self._relocate_gs(params, optimizers, binoms, n_max) if self.verbose: print(f"Step {step}: Relocated {n_relocated_gs} GSs.") # add new GSs - n_new_gs = self._add_new_gs(params, optimizers) + n_new_gs = self._add_new_gs(params, optimizers, binoms, n_max) if self.verbose: print( f"Step {step}: Added {n_new_gs} GSs. " @@ -141,6 +151,8 @@ def _relocate_gs( self, params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], optimizers: Dict[str, torch.optim.Optimizer], + binoms: Tensor, + n_max: int, ) -> int: opacities = torch.sigmoid(params["opacities"]) dead_mask = opacities <= self.min_opacity @@ -151,6 +163,8 @@ def _relocate_gs( optimizers=optimizers, state={}, mask=dead_mask, + binoms=binoms, + n_max=n_max, min_opacity=self.min_opacity, ) return n_gs @@ -160,6 +174,8 @@ def _add_new_gs( self, params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], optimizers: Dict[str, torch.optim.Optimizer], + binoms: Tensor, + n_max: int, ) -> int: current_n_points = len(params["means"]) n_target = min(self.cap_max, int(1.05 * current_n_points)) @@ -170,6 +186,8 @@ def _add_new_gs( optimizers=optimizers, state={}, n=n_gs, + binoms=binoms, + n_max=n_max, min_opacity=self.min_opacity, ) return n_gs diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 58d8c9922..cbb3705c3 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -204,6 +204,8 @@ def relocate( optimizers: Dict[str, torch.optim.Optimizer], state: Dict[str, Tensor], mask: Tensor, + binoms: Tensor, + n_max: int, min_opacity: float = 0.005, ): """Inplace relocate some dead Gaussians to the lives ones. @@ -228,6 +230,8 @@ def relocate( opacities=opacities[sampled_idxs], scales=torch.exp(params["scales"])[sampled_idxs], ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, + binoms=binoms, + n_max=n_max, ) new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity) @@ -256,6 +260,8 @@ def sample_add( optimizers: Dict[str, torch.optim.Optimizer], state: Dict[str, Tensor], n: int, + binoms: Tensor, + n_max: int, min_opacity: float = 0.005, ): opacities = torch.sigmoid(params["opacities"]) @@ -267,6 +273,8 @@ def sample_add( opacities=opacities[sampled_idxs], scales=torch.exp(params["scales"])[sampled_idxs], ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, + binoms=binoms, + n_max=n_max, ) new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity) diff --git a/tests/test_strategy.py b/tests/test_strategy.py index f652aeb0e..3cd634df7 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -61,8 +61,9 @@ def test_strategy(): # Test MCMCStrategy strategy = MCMCStrategy(verbose=True) strategy.check_sanity(params, optimizers) + state = strategy.initialize_state() render_colors.mean().backward(retain_graph=True) - strategy.step_post_backward(params, optimizers, step=600, info=info, lr=1e-3) + strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3) if __name__ == "__main__":