Skip to content

Commit

Permalink
move BINOMS to strategy state
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruilong Li committed Jul 12, 2024
1 parent 67f439b commit bcc843c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 19 deletions.
14 changes: 6 additions & 8 deletions gsplat/relocation.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import math

import torch
from torch import Tensor
from .cuda._wrapper import _make_lazy_cuda_func

N_MAX = 51
BINOMS = torch.zeros((N_MAX, N_MAX)).float().cuda()
for n in range(N_MAX):
for k in range(n + 1):
BINOMS[n, k] = math.comb(n, k)
from .cuda._wrapper import _make_lazy_cuda_func


def compute_relocation(
opacities: Tensor, # [N]
scales: Tensor, # [N, 3]
ratios: Tensor, # [N]
binoms: Tensor, # [n_max, n_max]
n_max: int,
) -> tuple[Tensor, Tensor]:
"""Compute new Gaussians from a set of old Gaussians.
Expand All @@ -38,10 +36,10 @@ def compute_relocation(
assert ratios.shape == (N,), ratios.shape
opacities = opacities.contiguous()
scales = scales.contiguous()
ratios.clamp_(min=1, max=N_MAX)
ratios.clamp_(min=1, max=n_max)
ratios = ratios.int().contiguous()

new_opacities, new_scales = _make_lazy_cuda_func("compute_relocation")(
opacities, scales, ratios, BINOMS, N_MAX
opacities, scales, ratios, binoms, n_max
)
return new_opacities, new_scales
38 changes: 28 additions & 10 deletions gsplat/strategy/mcmc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import math
from dataclasses import dataclass
from typing import Any, Callable, DefaultDict, Dict, List, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from gsplat import quat_scale_to_covar_preci
from gsplat.relocation import compute_relocation

from .base import Strategy
from .ops import inject_noise_to_position, relocate, sample_add

Expand Down Expand Up @@ -40,11 +38,12 @@ class MCMCStrategy(Strategy):
>>> optimizers: Dict[str, torch.optim.Optimizer] = ...
>>> strategy = MCMCStrategy()
>>> strategy.check_sanity(params, optimizers)
>>> strategy_state = strategy.initialize_state()
>>> for step in range(1000):
... render_image, render_alpha, info = rasterization(...)
... loss = ...
... loss.backward()
... strategy.step_post_backward(params, optimizers, step, info, lr=1e-3)
... strategy.step_post_backward(params, optimizers, strategy_state, step, info, lr=1e-3)
"""

Expand All @@ -56,9 +55,14 @@ class MCMCStrategy(Strategy):
min_opacity: float = 0.005
verbose: bool = False

# def initialize_state(self) -> Dict[str, Any]:
# """Initialize and return the running state for this strategy."""
# return {}
def initialize_state(self) -> Dict[str, Any]:
"""Initialize and return the running state for this strategy."""
n_max = 51
binoms = torch.zeros((n_max, n_max))
for n in range(n_max):
for k in range(n + 1):
binoms[n, k] = math.comb(n, k)
return {"binoms": binoms, "n_max": n_max}

def check_sanity(
self,
Expand Down Expand Up @@ -101,7 +105,7 @@ def step_post_backward(
self,
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
# state: Dict[str, Any],
state: Dict[str, Any],
step: int,
info: Dict[str, Any],
lr: float,
Expand All @@ -111,18 +115,24 @@ def step_post_backward(
Args:
lr (float): Learning rate for "means" attribute of the GS.
"""
# move to the correct device
state["binoms"] = state["binoms"].to(params["means"].device)

binoms = state["binoms"]
n_max = state["n_max"]

if (
step < self.refine_stop_iter
and step > self.refine_start_iter
and step % self.refine_every == 0
):
# teleport GSs
n_relocated_gs = self._relocate_gs(params, optimizers)
n_relocated_gs = self._relocate_gs(params, optimizers, binoms, n_max)
if self.verbose:
print(f"Step {step}: Relocated {n_relocated_gs} GSs.")

# add new GSs
n_new_gs = self._add_new_gs(params, optimizers)
n_new_gs = self._add_new_gs(params, optimizers, binoms, n_max)
if self.verbose:
print(
f"Step {step}: Added {n_new_gs} GSs. "
Expand All @@ -141,6 +151,8 @@ def _relocate_gs(
self,
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
binoms: Tensor,
n_max: int,
) -> int:
opacities = torch.sigmoid(params["opacities"])
dead_mask = opacities <= self.min_opacity
Expand All @@ -151,6 +163,8 @@ def _relocate_gs(
optimizers=optimizers,
state={},
mask=dead_mask,
binoms=binoms,
n_max=n_max,
min_opacity=self.min_opacity,
)
return n_gs
Expand All @@ -160,6 +174,8 @@ def _add_new_gs(
self,
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
binoms: Tensor,
n_max: int,
) -> int:
current_n_points = len(params["means"])
n_target = min(self.cap_max, int(1.05 * current_n_points))
Expand All @@ -170,6 +186,8 @@ def _add_new_gs(
optimizers=optimizers,
state={},
n=n_gs,
binoms=binoms,
n_max=n_max,
min_opacity=self.min_opacity,
)
return n_gs
8 changes: 8 additions & 0 deletions gsplat/strategy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def relocate(
optimizers: Dict[str, torch.optim.Optimizer],
state: Dict[str, Tensor],
mask: Tensor,
binoms: Tensor,
n_max: int,
min_opacity: float = 0.005,
):
"""Inplace relocate some dead Gaussians to the lives ones.
Expand All @@ -228,6 +230,8 @@ def relocate(
opacities=opacities[sampled_idxs],
scales=torch.exp(params["scales"])[sampled_idxs],
ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1,
binoms=binoms,
n_max=n_max,
)
new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)

Expand Down Expand Up @@ -256,6 +260,8 @@ def sample_add(
optimizers: Dict[str, torch.optim.Optimizer],
state: Dict[str, Tensor],
n: int,
binoms: Tensor,
n_max: int,
min_opacity: float = 0.005,
):
opacities = torch.sigmoid(params["opacities"])
Expand All @@ -267,6 +273,8 @@ def sample_add(
opacities=opacities[sampled_idxs],
scales=torch.exp(params["scales"])[sampled_idxs],
ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1,
binoms=binoms,
n_max=n_max,
)
new_opacities = torch.clamp(new_opacities, max=1.0 - eps, min=min_opacity)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def test_strategy():
# Test MCMCStrategy
strategy = MCMCStrategy(verbose=True)
strategy.check_sanity(params, optimizers)
state = strategy.initialize_state()
render_colors.mean().backward(retain_graph=True)
strategy.step_post_backward(params, optimizers, step=600, info=info, lr=1e-3)
strategy.step_post_backward(params, optimizers, state, step=600, info=info, lr=1e-3)


if __name__ == "__main__":
Expand Down

0 comments on commit bcc843c

Please sign in to comment.