Commit

implement the anchor pruning mechanism
MrNeRF committed Sep 19, 2024
1 parent f8fbb8f commit d7140a4
Showing 2 changed files with 97 additions and 43 deletions.
11 changes: 8 additions & 3 deletions gsplat/strategy/ops.py
@@ -181,13 +181,16 @@ def remove(
optimizers: Dict[str, torch.optim.Optimizer],
state: Dict[str, Tensor],
mask: Tensor,
names: Union[List[str], None] = None,
):
"""Inplace remove the Gaussian with the given mask.
Args:
params: A dictionary of parameters.
optimizers: A dictionary of optimizers, each corresponding to a parameter.
state: A dictionary of extra state tensors.
mask: A boolean mask to remove the Gaussians.
names: A list of key names to update. If None, update all. Default: None.
"""
sel = torch.where(~mask)[0]
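For context, a minimal usage sketch of the extended remove() with the new names argument; the toy parameter shapes, the hand-made mask, and the Adam optimizers are illustrative assumptions, not taken from the commit:

import torch
from gsplat.strategy.ops import remove

# Toy scaffold: 4 anchors with per-anchor positions and scales (illustrative shapes only).
params = torch.nn.ParameterDict({
    "anchors": torch.nn.Parameter(torch.randn(4, 3)),
    "scales": torch.nn.Parameter(torch.randn(4, 3)),
})
optimizers = {k: torch.optim.Adam([v]) for k, v in params.items()}
state = {"anchor_count": torch.zeros(4)}

# True = prune this anchor; only the listed parameter groups are touched.
is_prune = torch.tensor([True, False, True, False])
remove(params=params, optimizers=optimizers, state=state, mask=is_prune,
       names=["anchors", "scales"])

print(params["anchors"].shape)  # torch.Size([2, 3]) -- two anchors remain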

@@ -198,7 +201,7 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor:
return v[sel]

# update the parameters and the state in the optimizers
_update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers)
_update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names)
# update the extra running state
for k, v in state.items():
if isinstance(v, torch.Tensor):
@@ -360,7 +363,6 @@ def op_sigmoid(x, k=100, x0=0.995):
* (op_sigmoid(1 - opacities)).unsqueeze(-1)
* scaler
)
noise = torch.einsum("bij,bj->bi", covars, noise)
params["means"].add_(noise)


@@ -449,6 +451,9 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor:
# Update the extra running state
for k, v in state.items():
if isinstance(v, torch.Tensor):
if k == "anchor_count" or k == "anchor_opacity":
zeros = torch.zeros((num_new, *v.shape[1:]), device=device)
else:
zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=device)
state[k] = torch.cat([v, zeros], dim=0)
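A quick shape check of the padding rule above: anchor-level statistics grow by one row per new anchor, while per-offset statistics grow by n_feat_offsets rows per new anchor. The sizes below (num_new, n_feat_offsets, the 5 existing anchors) are assumed toy values:

import torch

num_new, n_feat_offsets, device = 2, 10, "cpu"  # assumed toy sizes
state = {
    "grad2d": torch.zeros(5 * n_feat_offsets),  # per offset Gaussian (5 existing anchors)
    "anchor_count": torch.zeros(5),             # per anchor
}

for k, v in state.items():
    if isinstance(v, torch.Tensor):
        if k == "anchor_count" or k == "anchor_opacity":
            zeros = torch.zeros((num_new, *v.shape[1:]), device=device)
        else:
            zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=device)
        state[k] = torch.cat([v, zeros], dim=0)

print(state["grad2d"].shape, state["anchor_count"].shape)  # torch.Size([70]) torch.Size([7])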

129 changes: 89 additions & 40 deletions gsplat/strategy/scaffold.py
@@ -118,9 +118,10 @@ def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]:
# - grad2d: running accum of the norm of the image plane gradients for each GS.
# - count: running accum of how many time each GS is visible.
# - radii: the radii of the GSs (normalized by the image resolution).
state = {"grad2d": None,
"count": None,
"anchor_count": None,
"anchor_opacity": None,
"scene_scale": scene_scale}
if self.refine_scale2d_stop_iter > 0:
state["radii"] = None
@@ -204,83 +205,131 @@ def step_post_backward(
)

# prune GSs
# n_prune = self._prune_gs(params, optimizers, state, step)
# if self.verbose:
# print(
# f"Step {step}: {n_prune} GSs pruned. "
# f"Now having {len(params['anchors'])} GSs."
# )
is_prune = (state["anchor_opacity"] < self.prune_opa * state["anchor_count"]).squeeze()
mask = state["anchor_opacity"] > self.prune_opa
print(mask.sum().item())

n_prune = is_prune.sum().item()
if n_prune > 0:
names = ["anchors", "scales", "quats", "features", "offsets"]
remove(params=params, optimizers=optimizers, state=state, mask=is_prune, names=names)

if self.verbose:
print(
f"Step {step}: {n_prune} GSs pruned. "
f"Now having {len(params['anchors'])} GSs."
)

# reset running stats
state["grad2d"].zero_()
state["count"].zero_()
state["anchor_count"].zero_()
state["anchor_opacity"].zero_()
device = params["anchors"].device
n_gaussian = params["anchors"].shape[0]
state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device)
state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device)
state["anchor_count"] = torch.zeros(n_gaussian, device=device)
state["anchor_opacity"] = torch.zeros(n_gaussian, device=device)

if self.refine_scale2d_stop_iter > 0:
state["radii"].zero_()
torch.cuda.empty_cache()
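The prune test above compares the accumulated opacity against prune_opa scaled by the visibility count, i.e. an anchor is dropped when its average rendered opacity falls below the threshold. A toy illustration; the threshold and the statistics are made-up values:

import torch

prune_opa = 0.005                                    # assumed threshold on average opacity
anchor_opacity = torch.tensor([0.10, 3.00, 0.001])   # summed opacity since the last reset
anchor_count = torch.tensor([50.0, 60.0, 40.0])      # times each anchor was rendered

# anchor_opacity < prune_opa * anchor_count  <=>  anchor_opacity / anchor_count < prune_opa
is_prune = anchor_opacity < prune_opa * anchor_count
print(is_prune)  # tensor([ True, False,  True])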

def _update_state(
self,
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
state: Dict[str, Any],
info: Dict[str, Any],
packed: bool = False,
):
for key in ["width", "height", "n_cameras", "radii", "gaussian_ids"]:
# Ensure required keys are present
required_keys = ["width", "height", "n_cameras", "radii", "gaussian_ids"]
for key in required_keys:
assert key in info, f"{key} is required but missing."

# normalize grads to [-1, 1] screen space
# Normalize gradients to [-1, 1] screen space
scale_factors = torch.tensor(
[info["width"] / 2.0 * info["n_cameras"], info["height"] / 2.0 * info["n_cameras"]],
device=info["means2d"].device,
)

if self.absgrad:
grads = info["means2d"].absgrad.clone()
grads = info["means2d"].absgrad.detach() * scale_factors
else:
grads = info["means2d"].grad.clone()
grads[..., 0] *= info["width"] / 2.0 * info["n_cameras"]
grads[..., 1] *= info["height"] / 2.0 * info["n_cameras"]
grads = info["means2d"].grad.detach() * scale_factors

# initialize state on the first run
# Initialize state on the first run
n_gaussian = params["anchors"].shape[0]
device = grads.device
if state["grad2d"] is None:
state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device)
state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device)
if state["count"] is None:
state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device)
state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device)
if state["anchor_count"] is None:
state["anchor_count"] = torch.zeros(n_gaussian, device=device)
if state["anchor_opacity"] is None:
state["anchor_opacity"] = torch.zeros(n_gaussian, device=device)
if self.refine_scale2d_stop_iter > 0 and state["radii"] is None:
assert "radii" in info, "radii is required but missing."
state["radii"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device)

state["radii"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device)

# update the running state
# Update the running state
if packed:
# grads is [nnz, 2]
gs_ids = info["gaussian_ids"] # [nnz]
radii = info["radii"] # [nnz]
else:
# grads is [C, N, 2]
sel = info["radii"] > 0.0 # [C, N]
gs_ids = torch.where(sel)[1] # [nnz]
gs_ids = sel.nonzero(as_tuple=False)[:, 1] # [nnz]
grads = grads[sel] # [nnz, 2]
radii = info["radii"][sel] # [nnz]
# update neural gaussian stats

# Compute valid_mask efficiently
visible_anchor_mask = info["visible_anchor_mask"]
neural_selection_mask = info["neural_selection_mask"]
# Extend anchor visibility to its per-offset Gaussians
anchor_visible_mask = visible_anchor_mask.unsqueeze(dim=1).repeat([1, self.n_feat_offsets]).view(-1)
neural_gaussian_mask = torch.zeros_like(state["grad2d"], dtype=torch.bool)
neural_gaussian_mask[anchor_visible_mask] = neural_selection_mask
valid_mask = neural_gaussian_mask[gs_ids]

# Filter gs_ids and grads based on the valid_mask
# Compute anchor indices and offset indices for gs_ids
anchor_indices = gs_ids // self.n_feat_offsets
offset_indices = gs_ids % self.n_feat_offsets

# Determine valid gs_ids based on visibility and selection masks
valid_mask = (
visible_anchor_mask[anchor_indices] & neural_selection_mask[gs_ids]
)

# Filter gs_ids and grads based on valid_mask
valid_gs_ids = gs_ids[valid_mask]
valid_grads_norm = grads.norm(dim=-1)[valid_mask]
valid_grads_norm = grads[valid_mask].norm(dim=-1)

# Update state using index_add_
state["grad2d"].index_add_(0, valid_gs_ids, valid_grads_norm)
state["count"].index_add_(0, valid_gs_ids, torch.ones_like(valid_gs_ids, dtype=torch.float32))

state["count"].index_add_(
0, valid_gs_ids, torch.ones_like(valid_gs_ids, dtype=torch.float32)
)

# Update anchor opacity and count
anchor_ids = visible_anchor_mask.nonzero(as_tuple=False).squeeze(-1)
neural_opacities = (
info["neural_opacities"]
.detach()
.view(-1, self.n_feat_offsets)
.clamp_min_(0)
.sum(dim=1)
)

state["anchor_opacity"].index_add_(0, anchor_ids, neural_opacities)
state["anchor_count"].index_add_(
0, anchor_ids, torch.ones_like(anchor_ids, dtype=torch.float32)
)

# Update radii if required
if self.refine_scale2d_stop_iter > 0:
# Should be ideally using scatter max
# Normalize radii to [0, 1] screen space
normalized_radii = radii / float(max(info["width"], info["height"]))
# Update radii using torch.maximum
state["radii"][gs_ids] = torch.maximum(
state["radii"][gs_ids],
# normalize radii to [0, 1] screen space
radii / float(max(info["width"], info["height"])),
state["radii"][gs_ids], normalized_radii
)
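Putting the bookkeeping above together: each rendered Gaussian id is mapped back to its anchor via integer division by n_feat_offsets, kept only if that anchor was visible and its offset was actually selected, and the surviving gradient norms and counts are scattered with index_add_. A self-contained toy run; the mask lengths and shapes are assumptions of this sketch, not guaranteed to match the renderer's info dict:

import torch

N, K = 2, 3                                   # 2 anchors x 3 offsets = 6 potential Gaussians
grad2d = torch.zeros(N * K)
count = torch.zeros(N * K)

visible_anchor_mask = torch.tensor([True, False])                  # per anchor
neural_selection_mask = torch.tensor([True, False, True,
                                      False, False, False])        # per offset Gaussian
gs_ids = torch.tensor([0, 2, 4])                                   # rendered Gaussian ids
grads = torch.randn(3, 2)                                          # their screen-space grads

anchor_indices = gs_ids // K
valid_mask = visible_anchor_mask[anchor_indices] & neural_selection_mask[gs_ids]

valid_gs_ids = gs_ids[valid_mask]
valid_grads_norm = grads[valid_mask].norm(dim=-1)

grad2d.index_add_(0, valid_gs_ids, valid_grads_norm)
count.index_add_(0, valid_gs_ids, torch.ones_like(valid_gs_ids, dtype=torch.float32))
print(count)  # ids 0 and 2 were counted; id 4 belongs to an anchor that was not visible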

@torch.no_grad()
