implement anchor growing algorithm
MrNeRF committed Sep 19, 2024
1 parent 7e71887 commit f8fbb8f
Showing 4 changed files with 280 additions and 69 deletions.
83 changes: 46 additions & 37 deletions examples/simple_trainer_scaffold.py
@@ -38,7 +38,7 @@

from gsplat.compression import PngCompression
from gsplat.distributed import cli
from gsplat.rendering import rasterization, filter_visible_gaussians
from gsplat.rendering import rasterization, view_to_visible_anchors
from gsplat.strategy import ScaffoldStrategy


@@ -162,6 +162,7 @@ def adjust_steps(self, factor: float):
strategy = self.strategy
strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
strategy.voxel_size = self.voxel_size
# strategy.reset_every = int(strategy.reset_every * factor)
# strategy.refine_every = int(strategy.refine_every * factor)

@@ -199,7 +200,7 @@ def create_splats_with_optimizers(
N = points.shape[0]
quats = torch.rand((N, 4)) # [N, 4]

features = torch.zeros((N, strategy.mean_feat_dim))
features = torch.zeros((N, strategy.feat_dim))
offsets = torch.zeros((N, strategy.n_feat_offsets, 3))

params = [
@@ -334,7 +335,7 @@ def __init__(
if cfg.app_opt:
assert feature_dim is not None
self.app_module = AppearanceOptModule(
len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree
len(self.trainset), feature_dim, cfg.app_embed_dim, None
).to(self.device)
# initialize the last layer to be zero so that the initial output is zero.
torch.nn.init.zeros_(self.app_module.color_head[-1].weight)
@@ -394,7 +395,7 @@ def __init__(
mode="training",
)

def get_visibility_mask(
def get_visible_anchor_mask(
self,
camtoworlds: Tensor,
Ks: Tensor,
@@ -408,7 +409,7 @@ def get_visibility_mask(
quats = self.splats["quats"] # [N, 4]
scales = torch.exp(self.splats["scales"])[:, :3] # [N, 3]

visibility_mask = filter_visible_gaussians(
visible_anchor_mask = view_to_visible_anchors(
means=anchors,
quats=quats,
scales=scales,
@@ -420,18 +421,18 @@ def get_visibility_mask(
rasterize_mode=rasterize_mode,
)

return visibility_mask
return visible_anchor_mask

def get_neural_gaussians(self, cam_pos, selection=None):
def get_neural_gaussians(self, cam_pos, visible_anchor_mask=None):

# If no visibility mask is provided, we select all anchors including their offsets
if selection is None:
selection = torch.ones(self.splats["anchors"].shape[0], dtype=torch.bool, device=self.device)
if visible_anchor_mask is None:
visible_anchor_mask = torch.ones(self.splats["anchors"].shape[0], dtype=torch.bool, device=self.device)

selected_features = self.splats["features"][selection] # [M, c]
selected_anchors = self.splats["anchors"][selection] # [M, 3]
selected_offsets = self.splats["offsets"][selection] # [M, k, 3]
selected_scales = torch.exp(self.splats["scales"][selection]) # [M, 6]
selected_features = self.splats["features"][visible_anchor_mask] # [M, c]
selected_anchors = self.splats["anchors"][visible_anchor_mask] # [M, 3]
selected_offsets = self.splats["offsets"][visible_anchor_mask] # [M, k, 3]
selected_scales = torch.exp(self.splats["scales"][visible_anchor_mask]) # [M, 6]

# See formula (5) in Scaffold-GS
view_dir = selected_anchors - cam_pos # [M, 3]
@@ -445,7 +446,7 @@ def get_neural_gaussians(self, cam_pos, selection=None):
# Apply MLPs (they output per-offset features concatenated along the last dimension)
neural_opacity = self.cfg.strategy.opacities_mlp(feature_view_dir) # [M, k*1]
neural_opacity = neural_opacity.view(-1, 1) # [M*k, 1]
pos_opacity_mask = (neural_opacity > 0.0).view(-1) # [M*k]
neural_selection_mask = (neural_opacity > 0.0).view(-1) # [M*k]

# Get color and reshape
neural_colors = self.cfg.strategy.colors_mlp(feature_view_dir) # [M, k*3]
@@ -461,22 +462,22 @@ def get_neural_gaussians(self, cam_pos, selection=None):
anchors_repeated = selected_anchors.unsqueeze(1).repeat(1, k, 1).view(-1, 3) # [M*k, 3]

# Apply positive opacity mask
selected_opacity = neural_opacity[pos_opacity_mask].squeeze(-1) # [m]
selected_colors = neural_colors[pos_opacity_mask] # [m, 3]
selected_scale_rot = neural_scale_rot[pos_opacity_mask] # [m, 7]
selected_offsets = selected_offsets[pos_opacity_mask] # [m, 3]
scales_repeated = scales_repeated[pos_opacity_mask] # [m, 6]
anchors_repeated = anchors_repeated[pos_opacity_mask] # [m, 3]
selected_opacity = neural_opacity[neural_selection_mask].squeeze(-1) # [m]
selected_colors = neural_colors[neural_selection_mask] # [m, 3]
selected_scale_rot = neural_scale_rot[neural_selection_mask] # [m, 7]
selected_offsets = selected_offsets[neural_selection_mask] # [m, 3]
scales_repeated = scales_repeated[neural_selection_mask] # [m, 6]
anchors_repeated = anchors_repeated[neural_selection_mask] # [m, 3]

# Compute scales and rotations
scales = scales_repeated[:, 3:] * torch.sigmoid(selected_scale_rot[:, :3]) # [m, 3]
rotation = torch.nn.functional.normalize(selected_scale_rot[:, 3:7]) # [m, 4]

# Compute offsets and anchors
offsets = selected_offsets * scales_repeated[:, :3] # [m, 3]
anchors = anchors_repeated + offsets # [m, 3]
means = anchors_repeated + offsets # [m, 3]

return anchors, selected_colors, selected_opacity, scales, rotation, neural_opacity, pos_opacity_mask
return means, selected_colors, selected_opacity, scales, rotation, neural_opacity, neural_selection_mask
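Taken together, get_neural_gaussians decodes each visible anchor into k candidate Gaussians: the anchor feature is combined with the camera-relative view direction, small MLPs predict per-offset opacity, color, and scale/rotation, and only offsets with positive opacity are kept. The sketch below restates that flow as a standalone function; the exact feature/view-direction combination and the mlp_* callables are assumptions for illustration, not the trainer's code.

```python
import torch
import torch.nn.functional as F

def decode_anchors(features, anchors, offsets, scales, cam_pos,
                   mlp_opacity, mlp_color, mlp_scale_rot, k):
    # Hypothetical sketch of the anchor -> neural-Gaussian decode.
    view_dir = F.normalize(anchors - cam_pos, dim=-1)       # [M, 3]
    feat_dir = torch.cat([features, view_dir], dim=-1)      # [M, c+3] (assumed combination)

    opacity = mlp_opacity(feat_dir).view(-1, 1)              # [M*k, 1]
    colors = mlp_color(feat_dir).view(-1, 3)                 # [M*k, 3]
    scale_rot = mlp_scale_rot(feat_dir).view(-1, 7)          # [M*k, 7]
    keep = (opacity > 0.0).view(-1)                          # [M*k] positive-opacity mask

    scales6 = scales.repeat_interleave(k, dim=0)             # [M*k, 6], already exp-activated
    anchors_rep = anchors.repeat_interleave(k, dim=0)        # [M*k, 3]

    means = anchors_rep + offsets.reshape(-1, 3) * scales6[:, :3]   # offsets in the anchor's scale
    out_scales = scales6[:, 3:] * torch.sigmoid(scale_rot[:, :3])   # per-offset scaling
    quats = F.normalize(scale_rot[:, 3:7], dim=-1)                   # per-offset rotation
    return (means[keep], colors[keep], opacity[keep].squeeze(-1),
            out_scales[keep], quats[keep])
```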

def rasterize_splats(
self,
@@ -487,23 +488,24 @@
**kwargs,
) -> Tuple[Tensor, Tensor, Dict, Tensor]:

visibility_mask = self.get_visibility_mask(camtoworlds=camtoworlds,
# We select only the visible anchors for faster inference
visible_anchor_mask = self.get_visible_anchor_mask(camtoworlds=camtoworlds,
Ks=Ks,
width=width,
height=height,
packed=self.cfg.packed,
rasterize_mode = "antialiased" if self.cfg.antialiased else "classic",
)

anchors, color_mlp, opacities, scales, quats, neural_opacity, selection_mask = self.get_neural_gaussians(camtoworlds[:, :3, 3], selection=visibility_mask)
# Decode the neural Gaussians spawned from the visible anchors
means, color_mlp, opacities, scales, quats, neural_opacity, neural_selection_mask = self.get_neural_gaussians(camtoworlds[:, :3, 3], visible_anchor_mask=visible_anchor_mask)

image_ids = kwargs.pop("image_ids", None)
if self.cfg.app_opt:
colors = self.app_module(
features=self.splats["features"],
embed_ids=image_ids,
dirs=anchors[None, :, :] - camtoworlds[:, None, :3, 3],
sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree),
dirs=means[None, :, :] - camtoworlds[:, None, :3, 3],
)
colors = colors + color_mlp
colors = torch.sigmoid(colors)
@@ -512,7 +514,7 @@

rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
render_colors, render_alphas, info = rasterization(
means=anchors,
means=means,
quats=quats,
scales=scales,
opacities=opacities,
Expand All @@ -528,6 +530,13 @@ def rasterize_splats(
distributed=self.world_size > 1,
**kwargs,
)
info.update(
{
"visible_anchor_mask": visible_anchor_mask,
"neural_selection_mask": neural_selection_mask,
"neural_opacities": neural_opacity,
}
)
return render_colors, render_alphas, info, scales
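The three entries added to info give the densification strategy what it needs to route per-Gaussian gradients back to the anchors and offsets that produced them: visible_anchor_mask marks which of the N anchors were decoded at all, and neural_selection_mask marks which of the M*k candidate offsets survived the opacity test. A minimal sketch of that bookkeeping, assuming a running [N*k] accumulator named grad_accum (a placeholder, not the actual strategy state):

```python
import torch

def accumulate_offset_grads(grad_accum, grad_norm, info, k):
    # Sketch: route gradient magnitudes of the m rendered Gaussians back to
    # their anchor-offset slots in an [N*k] accumulator.
    visible = info["visible_anchor_mask"]       # [N] bool: anchors that were decoded
    selected = info["neural_selection_mask"]    # [M*k] bool: offsets with opacity > 0

    # Fill the M*k slots of the visible anchors with the m gradient values ...
    visible_slots = torch.zeros(selected.shape[0], device=grad_norm.device)
    visible_slots[selected] = grad_norm

    # ... then scatter them into the full N*k accumulator.
    grad_accum[visible.repeat_interleave(k)] += visible_slots
    return grad_accum
```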

def train(self):
@@ -755,14 +764,14 @@ def train(self):
)

# For now no post steps
# self.cfg.strategy.step_post_backward(
# params=self.splats,
# optimizers=self.optimizers,
# state=self.strategy_state,
# step=step,
# info=info,
# packed=cfg.packed,
# )
self.cfg.strategy.step_post_backward(
params=self.splats,
optimizers=self.optimizers,
state=self.strategy_state,
step=step,
info=info,
packed=cfg.packed,
)
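With step_post_backward enabled, the ScaffoldStrategy can run the anchor-growing algorithm this commit adds. The strategy internals are not part of this diff; the sketch below only illustrates the kind of pipeline that would feed grow_anchors (added in gsplat/strategy/ops.py below), with grow_threshold, grad_accum, and count invented for illustration.

```python
import torch
from gsplat.strategy.ops import grow_anchors

def maybe_grow(params, optimizers, state, grad_accum, count, voxel_size,
               n_feat_offsets, feat_dim, grow_threshold=0.0002):
    # Hypothetical sketch of an anchor-growing step.
    # 1. Offsets whose averaged gradient exceeds a threshold become candidates.
    gradient_mask = (grad_accum / count.clamp_min(1)) > grow_threshold   # [N*k] bool

    # 2. Candidate positions = anchor + scaled offset, snapped to the voxel grid.
    anchors = params["anchors"]                                          # [N, 3]
    offsets = params["offsets"] * torch.exp(params["scales"])[:, None, :3]
    candidates = (anchors[:, None, :] + offsets).reshape(-1, 3)[gradient_mask]
    grid_coords = torch.round(candidates / voxel_size)

    # 3. Deduplicate candidates that fall into the same voxel.
    unique_grid, inv_idx = torch.unique(grid_coords, dim=0, return_inverse=True)

    # 4. Keep only voxels not already occupied by an existing anchor
    #    (sketched here as "keep all").
    remove_duplicates_mask = torch.ones(
        unique_grid.shape[0], dtype=torch.bool, device=anchors.device
    )

    grow_anchors(
        params, optimizers, state,
        anchors=unique_grid[remove_duplicates_mask],
        gradient_mask=gradient_mask,
        remove_duplicates_mask=remove_duplicates_mask,
        inv_idx=inv_idx,
        voxel_size=voxel_size,
        n_feat_offsets=n_feat_offsets,
        feat_dim=feat_dim,
    )
```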

# Turn Gradients into Sparse Tensor before running optimizer
if cfg.sparse_grad:
@@ -842,7 +851,7 @@ def eval(self, step: int, stage: str = "val"):
Ks=Ks,
width=width,
height=height,
sh_degree=cfg.sh_degree,
sh_degree=None,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
) # [1, H, W, 3]
@@ -946,7 +955,7 @@ def render_traj(self, step: int):
Ks=Ks,
width=width,
height=height,
sh_degree=cfg.sh_degree,
sh_degree=None,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
render_mode="RGB+ED",
2 changes: 1 addition & 1 deletion gsplat/rendering.py
@@ -582,7 +582,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor:
return render_colors, render_alphas, meta


def filter_visible_gaussians(
def view_to_visible_anchors(
means: Tensor, # [N, 3]
quats: Tensor, # [N, 4]
scales: Tensor, # [N, 3]
91 changes: 91 additions & 0 deletions gsplat/strategy/ops.py
@@ -4,6 +4,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_scatter import scatter_max

from gsplat import quat_scale_to_covar_preci
from gsplat.relocation import compute_relocation
@@ -361,3 +362,93 @@ def op_sigmoid(x, k=100, x0=0.995):
)
noise = torch.einsum("bij,bj->bi", covars, noise)
params["means"].add_(noise)


@torch.no_grad()
def grow_anchors(
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
state: Dict[str, Tensor],
anchors: torch.Tensor,
gradient_mask: torch.Tensor,
remove_duplicates_mask: torch.Tensor,
inv_idx: torch.Tensor,
voxel_size: float,
n_feat_offsets: int,
feat_dim: int,
):
"""Inplace add new Gaussians (anchors) to the parameters.
Args:
params: A dictionary of parameters.
optimizers: A dictionary of optimizers, each corresponding to a parameter.
state: A dictionary of extra state tensors.
anchors: Positions of new anchors to be added.
gradient_mask: A mask to select gradients.
remove_duplicates_mask: A mask to remove duplicates.
inv_idx: Indices for inverse mapping.
voxel_size: The size of the voxel.
n_feat_offsets: Number of feature offsets.
feat_dim: Dimension of features.
"""
device = anchors.device
num_new = anchors.size(0)

# Scale anchors
anchors = anchors * voxel_size # [N_new, 3]

# Initialize new parameters
log_voxel_size = torch.log(torch.tensor(voxel_size, device=device))
scaling = log_voxel_size.expand(num_new, anchors.size(1) * 2) # [N_new, 6]

rotation = torch.zeros((num_new, 4), device=device)
rotation[:, 0] = 1.0 # Identity quaternion

# Prepare new features
existing_features = params["features"] # [N_existing, feat_dim]
repeated_features = existing_features.repeat_interleave(n_feat_offsets, dim=0) # [N_existing * n_feat_offsets, feat_dim]

selected_features = repeated_features[gradient_mask] # [N_selected, feat_dim]

# Use inverse_indices to aggregate features
scattered_features, _ = scatter_max(
selected_features,
inv_idx.unsqueeze(1).expand(-1, feat_dim),
dim=0
)
feat = scattered_features[remove_duplicates_mask] # [N_new, feat_dim]

# Initialize new offsets
offsets = torch.zeros((num_new, n_feat_offsets, 3), device=device) # [N_new, n_feat_offsets, 3]

def param_fn(name: str, p: Tensor) -> Tensor:
if name == "anchors":
p_new = torch.cat([p, anchors], dim=0)
elif name == "scales":
p_new = torch.cat([p, scaling], dim=0)
elif name == "quats":
p_new = torch.cat([p, rotation], dim=0)
elif name == "features":
p_new = torch.cat([p, feat], dim=0)
elif name == "offsets":
p_new = torch.cat([p, offsets], dim=0)
else:
raise ValueError(f"Parameter '{name}' not recognized.")
return torch.nn.Parameter(p_new)

def optimizer_fn(key: str, v: Tensor) -> Tensor:
# Extend optimizer state tensors with zeros
zeros = torch.zeros((num_new, *v.shape[1:]), device=device)
v_new = torch.cat([v, zeros], dim=0)
return v_new

# Update parameters and optimizer states
names = ["anchors", "scales", "quats", "features", "offsets"]
_update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names)

# Update the extra running state
for k, v in state.items():
if isinstance(v, torch.Tensor):
zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=device)
state[k] = torch.cat([v, zeros], dim=0)
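New anchors inherit their feature vector by max-pooling over all high-gradient offsets that landed in the same voxel, which is what the torch_scatter.scatter_max call above does (note that torch_scatter is an extra dependency). A tiny self-contained example of that reduction, with made-up values:

```python
import torch
from torch_scatter import scatter_max

# Three candidate feature rows; the first two fall into voxel 0, the last into voxel 1.
selected_features = torch.tensor([[1.0, 2.0],
                                  [3.0, 0.0],
                                  [5.0, 4.0]])
inv_idx = torch.tensor([0, 0, 1])

# Element-wise maximum per voxel -> [[3., 2.], [5., 4.]]
scattered, _ = scatter_max(
    selected_features, inv_idx.unsqueeze(1).expand(-1, 2), dim=0
)
print(scattered)
```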
