diff --git a/.gitignore b/.gitignore index 65a25f051..9ccfebea6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,9 @@ compile_commands.json # Visual Studio Code configs. .vscode/ +# Pycharm +.idea + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/examples/benchmarks/scaffold.sh b/examples/benchmarks/scaffold.sh new file mode 100644 index 000000000..6358077f5 --- /dev/null +++ b/examples/benchmarks/scaffold.sh @@ -0,0 +1,53 @@ +SCENE_DIR="data/360_v2" +RESULT_DIR="results/benchmark" +SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers +RENDER_TRAJ_PATH="ellipse" + +for SCENE in $SCENE_LIST; +do + if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then + DATA_FACTOR=2 + else + DATA_FACTOR=4 + fi + + echo "Running $SCENE" + + # train without eval + CUDA_VISIBLE_DEVICES=0 python simple_trainer_scaffold.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ + --render_traj_path $RENDER_TRAJ_PATH \ + --data_dir data/360_v2/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ + + # run eval and render + for CKPT in $RESULT_DIR/$SCENE/ckpts/*; + do + CUDA_VISIBLE_DEVICES=0 python simple_trainer_scaffold.py --disable_viewer --data_factor $DATA_FACTOR \ + --render_traj_path $RENDER_TRAJ_PATH \ + --data_dir data/360_v2/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ \ + --ckpt $CKPT + done +done + + +for SCENE in $SCENE_LIST; +do + echo "=== Eval Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/val*.json; + do + echo $STATS + cat $STATS; + echo + done + + echo "=== Train Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json; + do + echo $STATS + cat $STATS; + echo + done +done diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py new file mode 100644 index 000000000..45af847c3 --- /dev/null +++ b/examples/simple_trainer_scaffold.py @@ -0,0 +1,1147 @@ +import json +import math +import os +import time +from dataclasses import dataclass, field +from collections import defaultdict +from typing import Dict, List, Optional, Tuple + +import imageio +import nerfview +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import viser +import yaml +from torch.nn import ModuleDict, ParameterDict +from torch.optim import SparseAdam, Adam + +from datasets.colmap import Dataset, Parser +from datasets.traj import ( + generate_interpolated_path, + generate_ellipse_path_z, + generate_spiral_path, +) +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from fused_ssim import fused_ssim +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from typing_extensions import Literal +from utils import CameraOptModule, knn, set_random_seed +from lib_bilagrid import ( + BilateralGrid, + slice, + color_correct, + total_variation_loss, +) + +from gsplat.distributed import cli +from gsplat.rendering import rasterization, view_to_visible_anchors +from gsplat.strategy import ScaffoldStrategy + + +@dataclass +class Config: + # Disable viewer + disable_viewer: bool = False + # Path to the .pt files. If provide, it will skip training and run evaluation only. + ckpt: Optional[List[str]] = None + # Render trajectory path + render_traj_path: str = "ellipse" + + # Path to the Mip-NeRF 360 dataset + data_dir: str = "examples/data/360_v2/garden" + # Downsample factor for the dataset + data_factor: int = 4 + # Directory to save results + result_dir: str = "results" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + # Normalize the world space + normalize_world_space: bool = True + # Dimensionality of anchor features + feat_dim: int = 128 + # Number offsets + n_feat_offsets: int = 10 + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 30_000 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + + # voxel size for Scaffold-GS + voxel_size = 0.001 + # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm + init_extent: float = 3.0 + # Turn on another SH degree every this steps + init_opa: float = 0.1 + # Initial scale of GS + init_scale: float = 1.0 + # Weight for SSIM loss + ssim_lambda: float = 0.2 + + # Near plane clipping distance + near_plane: float = 0.01 + # Far plane clipping distance + far_plane: float = 1e10 + + # Strategy for GS densification + strategy: ScaffoldStrategy = field(default_factory=ScaffoldStrategy) + # Use packed mode for rasterization, this leads to less memory usage but slightly slower. + packed: bool = False + # Use sparse gradients for optimization. (experimental) + sparse_grad: bool = False + # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. + antialiased: bool = False + + # Use random background for training to discourage transparency + random_bkgd: bool = False + + # Scale regularization + scale_reg: float = 0.001 + + # Enable camera optimization. + pose_opt: bool = False + # Learning rate for camera optimization + pose_opt_lr: float = 1e-5 + # Regularization for camera optimization as weight decay + pose_opt_reg: float = 1e-6 + # Add noise to camera extrinsics. This is only to test the camera pose optimization. + pose_noise: float = 0.0 + + # Enable bilateral grid. (experimental) + use_bilateral_grid: bool = False + # Shape of the bilateral grid (X, Y, W) + bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8) + + # Enable depth loss. (experimental) + depth_loss: bool = False + # Weight for depth loss + depth_lambda: float = 1e-2 + + # Dump information to tensorboard every this steps + tb_every: int = 100 + # Save training images to tensorboard + tb_save_image: bool = False + + lpips_net: Literal["vgg", "alex"] = "alex" + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.max_steps = int(self.max_steps * factor) + + strategy = self.strategy + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.voxel_size = self.voxel_size + + +def create_splats_with_optimizers( + parser: Parser, + init_extent: float = 3.0, + init_opacity: float = 0.1, + init_scale: float = 1.0, + scene_scale: float = 1.0, + sparse_grad: bool = False, + batch_size: int = 1, + feature_dim: Optional[int] = None, + device: str = "cuda", + world_rank: int = 0, + world_size: int = 1, + voxel_size: float = 0.001, +) -> tuple[ + dict[str, ModuleDict | ParameterDict], dict[str, dict[str, SparseAdam | Adam]] +]: + + # Compare GS-Scaffold paper formula (4) + points = np.unique(np.round(parser.points / voxel_size), axis=0) * voxel_size + points = torch.from_numpy(points).float() + + # Initialize the GS size to be the average dist of the 3 nearest neighbors + dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 6) # [N, 3] + + # Distribute the GSs to different ranks (also works for single rank) + points = points[world_rank::world_size] + scales = scales[world_rank::world_size] + N = points.shape[0] + quats = torch.rand((N, 4)) # [N, 4] + + features = torch.zeros((N, cfg.feat_dim)) + offsets = torch.zeros((N, cfg.n_feat_offsets, 3)) + + opacities = torch.logit(torch.full((N, 1), init_opacity)) # [N,] + + # Define learning rates for gauss_params and decoders + learning_rates = { + "anchors": 0.0001, + "scales": 0.007, + "quats": 0.002, + "features": 0.0075, + "opacities": 5e-2, + "offsets": 0.01 * scene_scale, + "opacities_mlp": 0.002, + "colors_mlp": 0.008, + "scale_rot_mlp": 0.0004, + } + + # Define gauss_params + gauss_params = torch.nn.ParameterDict( + { + "anchors": torch.nn.Parameter(points), + "scales": torch.nn.Parameter(scales), + "quats": torch.nn.Parameter(quats), + "features": torch.nn.Parameter(features), + "opacities": torch.nn.Parameter(opacities), + "offsets": torch.nn.Parameter(offsets), + } + ).to(device) + + # Define the MLPs (decoders) + colors_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(cfg.feat_dim + 3, cfg.feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(cfg.feat_dim, 3 * cfg.n_feat_offsets), + torch.nn.Sigmoid(), + ).cuda() + + opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(cfg.feat_dim + 3, cfg.feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(cfg.feat_dim, cfg.n_feat_offsets), + torch.nn.Tanh(), + ).cuda() + + scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(cfg.feat_dim + 3, cfg.feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(cfg.feat_dim, 7 * cfg.n_feat_offsets), + ).cuda() + + # Initialize decoders (MLPs) + decoders = torch.nn.ModuleDict( + { + "opacities_mlp": opacities_mlp, + "colors_mlp": colors_mlp, + "scale_rot_mlp": scale_rot_mlp, + } + ).to(device) + + # Scale learning rates based on batch size (BS) + BS = batch_size * world_size + + # Create optimizers for gauss_params + gauss_optimizers = { + name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + [{"params": param, "lr": learning_rates[name] * math.sqrt(BS)}], + eps=1e-15 / math.sqrt(BS), + betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + ) + for name, param in gauss_params.items() + } + + # Create optimizers for decoders + decoders_optimizers = { + name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + [ + { + "params": decoder.parameters(), + "lr": learning_rates[name] * math.sqrt(BS), + } + ], + eps=1e-15 / math.sqrt(BS), + betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + ) + for name, decoder in decoders.items() + } + + # Combine gauss_params and decoders optimizers into a dictionary of dictionaries + optimizers = { + "gauss_optimizer": gauss_optimizers, + "decoders_optimizer": decoders_optimizers, + } + + # Return the gauss_params, decoders, and the dictionary of dictionaries for optimizers + splats = {"gauss_params": gauss_params, "decoders": decoders} + return splats, optimizers + + +class Runner: + """Engine for training and testing.""" + + def __init__( + self, local_rank: int, world_rank, world_size: int, cfg: Config + ) -> None: + set_random_seed(42 + local_rank) + + self.cfg = cfg + self.cfg.strategy.voxel_size = self.cfg.voxel_size + self.world_rank = world_rank + self.local_rank = local_rank + self.world_size = world_size + self.device = f"cuda:{local_rank}" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}/ckpts" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # Load data: Training data should contain initial points and colors. + self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=cfg.normalize_world_space, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + print("Scene scale:", self.scene_scale) + + # Model + self.splats, self.optimizers = create_splats_with_optimizers( + self.parser, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + scene_scale=self.scene_scale, + sparse_grad=cfg.sparse_grad, + batch_size=cfg.batch_size, + device=self.device, + world_rank=world_rank, + world_size=world_size, + voxel_size=cfg.voxel_size, + ) + print( + "Model initialized. Number of GS:", + len(self.splats["gauss_params"]["anchors"]), + ) + + # Densification Strategy + self.cfg.strategy.check_sanity( + self.splats["gauss_params"], self.optimizers["gauss_optimizer"] + ) + + self.strategy_state = self.cfg.strategy.initialize_state( + scene_scale=self.scene_scale, + feat_dim=cfg.feat_dim, + n_feat_offsets=cfg.n_feat_offsets, + ) + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + if world_size > 1: + self.pose_adjust = DDP(self.pose_adjust) + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + if world_size > 1: + self.pose_perturb = DDP(self.pose_perturb) + + self.bil_grid_optimizers = [] + if cfg.use_bilateral_grid: + self.bil_grids = BilateralGrid( + len(self.trainset), + grid_X=cfg.bilateral_grid_shape[0], + grid_Y=cfg.bilateral_grid_shape[1], + grid_W=cfg.bilateral_grid_shape[2], + ).to(self.device) + self.bil_grid_optimizers = [ + torch.optim.Adam( + self.bil_grids.parameters(), + lr=2e-3 * math.sqrt(cfg.batch_size), + eps=1e-15, + ), + ] + + # Losses & Metrics. + self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + + if cfg.lpips_net == "alex": + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="alex", normalize=True + ).to(self.device) + elif cfg.lpips_net == "vgg": + # The 3DGS official repo uses lpips vgg, which is equivalent with the following: + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=False + ).to(self.device) + else: + raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}") + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) + + def get_neural_gaussians( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + packed: bool, + rasterize_mode: str, + ): + + # Compare paper: Helps mainly to speed up the rasterization. Has no quality impact + visible_anchor_mask = view_to_visible_anchors( + means=self.splats["gauss_params"]["anchors"], + quats=self.splats["gauss_params"]["quats"], + scales=torch.exp(self.splats["gauss_params"]["scales"][:, :3]), + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=packed, + rasterize_mode=rasterize_mode, + ) + # If no visibility mask is provided, we select all anchors including their offsets + selected_features = self.splats["gauss_params"]["features"][ + visible_anchor_mask + ] # [M, c] + selected_anchors = self.splats["gauss_params"]["anchors"][ + visible_anchor_mask + ] # [M, 3] + selected_offsets = self.splats["gauss_params"]["offsets"][ + visible_anchor_mask + ] # [M, k, 3] + selected_quats = self.splats["gauss_params"]["quats"][ + visible_anchor_mask + ] # [M, 4] + selected_scales = torch.exp( + self.splats["gauss_params"]["scales"][visible_anchor_mask] + ) # [M, 6] + + # See formula (5) in Scaffold-GS + + cam_pos = camtoworlds[:, :3, 3] + view_dir = selected_anchors - cam_pos # [M, 3] + length = view_dir.norm(dim=1, keepdim=True) + view_dir_normalized = view_dir / length # [M, 3] + + # See formula (9) and the appendix for the rest + feature_view_dir = torch.cat( + [selected_features, view_dir_normalized], dim=1 + ) # [M, c+3] + + k = self.cfg.n_feat_offsets # Number of offsets per anchor + + # Apply MLPs (they output per-offset features concatenated along the last dimension) + neural_opacity = self.splats["decoders"]["opacities_mlp"]( + feature_view_dir + ) # [M, k*1] + neural_opacity = neural_opacity.view(-1, 1) # [M*k, 1] + neural_selection_mask = (neural_opacity > 0.0).view(-1) # [M*k] + + # Get color and reshape + neural_colors = self.splats["decoders"]["colors_mlp"]( + feature_view_dir + ) # [M, k*3] + neural_colors = neural_colors.view(-1, 3) # [M*k, 3] + + # Get scale and rotation and reshape + neural_scale_rot = self.splats["decoders"]["scale_rot_mlp"]( + feature_view_dir + ) # [M, k*7] + neural_scale_rot = neural_scale_rot.view(-1, 7) # [M*k, 7] + + # Reshape selected_offsets, scales, and anchors + selected_offsets = selected_offsets.view(-1, 3) # [M*k, 3] + scales_repeated = ( + selected_scales.unsqueeze(1).repeat(1, k, 1).view(-1, 6) + ) # [M*k, 6] + anchors_repeated = ( + selected_anchors.unsqueeze(1).repeat(1, k, 1).view(-1, 3) + ) # [M*k, 3] + quats_repeated = ( + selected_quats.unsqueeze(1).repeat(1, k, 1).view(-1, 4) + ) # [M*k, 3] + + # Apply positive opacity mask + selected_opacity = neural_opacity[neural_selection_mask].squeeze(-1) # [M] + selected_colors = neural_colors[neural_selection_mask] # [M, 3] + selected_scale_rot = neural_scale_rot[neural_selection_mask] # [M, 7] + selected_offsets = selected_offsets[neural_selection_mask] # [M, 3] + scales_repeated = scales_repeated[neural_selection_mask] # [M, 6] + anchors_repeated = anchors_repeated[neural_selection_mask] # [M, 3] + quats_repeated = quats_repeated[neural_selection_mask] # [M, 3] + + # Compute scales and rotations + scales = scales_repeated[:, 3:] * torch.sigmoid( + selected_scale_rot[:, :3] + ) # [M, 3] + + def quaternion_multiply(q1, q2): + # Extract individual components of the quaternions + w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3] + w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3] + + # Perform the quaternion multiplication + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + + # Stack the result back into shape (N, 4) + return torch.stack((w, x, y, z), dim=-1) + + # The rasterizer takes care of the normalization + rotation = quaternion_multiply(quats_repeated, selected_scale_rot[:, 3:7]) + + # Compute offsets and anchors + offsets = selected_offsets * scales_repeated[:, :3] # [M, 3] + means = anchors_repeated + offsets # [M, 3] + + v_a = ( + visible_anchor_mask.unsqueeze(dim=1) + .repeat([1, self.cfg.n_feat_offsets]) + .view(-1) + ) + all_neural_gaussians = torch.zeros_like(v_a, dtype=torch.bool) + all_neural_gaussians[v_a] = neural_selection_mask + + info = { + "means": means, + "colors": selected_colors, + "opacities": selected_opacity, + "scales": scales, + "quats": rotation, + "neural_opacities": neural_opacity, + "neural_selection_mask": all_neural_gaussians, + "visible_anchor_mask": visible_anchor_mask, + } + return info + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict]: + + # Get all the gaussians per voxel spawned from the anchors + info = self.get_neural_gaussians( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + packed=self.cfg.packed, + rasterize_mode="antialiased" if self.cfg.antialiased else "classic", + ) + + colors = info["colors"] # [N, K, 3] + + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" + render_colors, render_alphas, raster_info = rasterization( + means=info["means"], + quats=info["quats"], + scales=info["scales"], + opacities=info["opacities"], + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=self.cfg.strategy.absgrad, + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + distributed=self.world_size > 1, + **kwargs, + ) + raster_info.update(info) + return render_colors, render_alphas, raster_info + + def train(self): + cfg = self.cfg + device = self.device + world_rank = self.world_rank + world_size = self.world_size + + # Dump cfg. + if world_rank == 0: + with open(f"{cfg.result_dir}/cfg.yml", "w") as f: + yaml.dump(vars(cfg), f) + + max_steps = cfg.max_steps + init_step = 0 + + schedulers = [ + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["gauss_optimizer"]["anchors"], + gamma=0.001 ** (1.0 / max_steps), + ), + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["gauss_optimizer"]["offsets"], + gamma=(0.01 * self.scene_scale) ** (1.0 / max_steps), + ), + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["decoders_optimizer"]["opacities_mlp"], + gamma=0.001 ** (1.0 / max_steps), + ), + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["decoders_optimizer"]["colors_mlp"], + gamma=0.00625 ** (1.0 / max_steps), + ), + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["decoders_optimizer"]["scale_rot_mlp"], gamma=1.0 + ), + ] + if cfg.pose_opt: + # pose optimization has a learning rate schedule + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ) + ) + if cfg.use_bilateral_grid: + # bilateral grid has a learning rate schedule. Linear warmup for 1000 steps. + schedulers.append( + torch.optim.lr_scheduler.ChainedScheduler( + [ + torch.optim.lr_scheduler.LinearLR( + self.bil_grid_optimizers[0], + start_factor=0.01, + total_iters=1000, + ), + torch.optim.lr_scheduler.ExponentialLR( + self.bil_grid_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ), + ] + ) + ) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=4, + persistent_workers=True, + pin_memory=True, + ) + trainloader_iter = iter(trainloader) + + # Training loop. + global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + + self.splats["decoders"]["scale_rot_mlp"].train() + self.splats["decoders"]["opacities_mlp"].train() + self.splats["decoders"]["colors_mlp"].train() + + for step in pbar: + if not cfg.disable_viewer: + while self.viewer.state.status == "paused": + time.sleep(0.01) + self.viewer.lock.acquire() + tic = time.time() + + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + num_train_rays_per_step = ( + pixels.shape[0] * pixels.shape[1] * pixels.shape[2] + ) + image_ids = data["image_id"].to(device) + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + if cfg.pose_noise: + camtoworlds = self.pose_perturb(camtoworlds, image_ids) + + if cfg.pose_opt: + camtoworlds = self.pose_adjust(camtoworlds, image_ids) + + # forward + renders, alphas, info = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=None, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED" if cfg.depth_loss else "RGB", + ) + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + + if cfg.use_bilateral_grid: + grid_y, grid_x = torch.meshgrid( + (torch.arange(height, device=self.device) + 0.5) / height, + (torch.arange(width, device=self.device) + 0.5) / width, + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"] + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + colors = colors + bkgd * (1.0 - alphas) + + self.cfg.strategy.step_pre_backward( + params=self.splats["gauss_params"], + optimizers=self.optimizers["gauss_optimizer"], + state=self.strategy_state, + step=step, + info=info, + ) + + # loss + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - fused_ssim( + colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" + ) + loss = l1loss * (1.0 - cfg.ssim_lambda) + loss += ssimloss * cfg.ssim_lambda + loss += info["scales"].prod(dim=1).mean() * cfg.scale_reg + + if cfg.depth_loss: + # query depths from depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate loss in disparity space + disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale + loss += depthloss * cfg.depth_lambda + if cfg.use_bilateral_grid: + tvloss = 10 * total_variation_loss(self.bil_grids.grids) + loss += tvloss + + loss.backward() + + desc = f"loss={loss.item():.3f}| " + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + if cfg.pose_opt and cfg.pose_noise: + # monitor the pose error if we inject noise + pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) + desc += f"pose err={pose_err.item():.6f}| " + pbar.set_description(desc) + + if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: + mem = torch.cuda.max_memory_allocated() / 1024**3 + self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar( + "train/num_GS", len(self.splats["gauss_params"]["anchors"]), step + ) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.use_bilateral_grid: + self.writer.add_scalar("train/tvloss", tvloss.item(), step) + if cfg.tb_save_image: + canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + # save checkpoint before updating the model + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["gauss_params"]["anchors"]), + } + print("Step: ", step, stats) + with open( + f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", + "w", + ) as f: + json.dump(stats, f) + data = { + "step": step, + "feat_dim": self.cfg.feat_dim, + "n_feat_offsets": self.cfg.n_feat_offsets, + "gauss_params": self.splats["gauss_params"].state_dict(), + "opacities_mlp": self.splats["decoders"][ + "opacities_mlp" + ].state_dict(), + "colors_mlp": self.splats["decoders"]["colors_mlp"].state_dict(), + "scale_rot_mlp": self.splats["decoders"][ + "scale_rot_mlp" + ].state_dict(), + } + if cfg.pose_opt: + if world_size > 1: + data["pose_adjust"] = self.pose_adjust.module.state_dict() + else: + data["pose_adjust"] = self.pose_adjust.state_dict() + torch.save( + data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" + ) + + # For now no post steps + self.cfg.strategy.step_post_backward( + params=self.splats["gauss_params"], + optimizers=self.optimizers["gauss_optimizer"], + state=self.strategy_state, + step=step, + info=info, + packed=cfg.packed, + ) + + # Turn Gradients into Sparse Tensor before running optimizer + if cfg.sparse_grad: + assert cfg.packed, "Sparse gradients only work with packed mode." + gaussian_ids = info["gaussian_ids"] + for k in self.splats["gauss_params"].keys(): + grad = self.splats["gauss_params"][k].grad + if grad is None or grad.is_sparse: + continue + self.splats["gauss_params"][k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] + size=self.splats["gauss_params"][k].size(), # [N, ...] + is_coalesced=len(Ks) == 1, + ) + + # optimize + for optimizer in self.optimizers["gauss_optimizer"].values(): + optimizer.step() + optimizer.zero_grad(set_to_none=True) + # optimize + for optimizer in self.optimizers["decoders_optimizer"].values(): + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.pose_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.bil_grid_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps]: + self.eval( + step, + n_feat_offsets=self.cfg.n_feat_offsets, + feat_dim=self.cfg.feat_dim, + ) + self.render_traj(step) + + if not cfg.disable_viewer: + self.viewer.lock.release() + num_train_steps_per_sec = 1.0 / (time.time() - tic) + num_train_rays_per_sec = ( + num_train_rays_per_step * num_train_steps_per_sec + ) + # Update the viewer state. + self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec + # Update the scene. + self.viewer.update(step, num_train_rays_per_step) + + @torch.no_grad() + def eval(self, step: int, n_feat_offsets: int, feat_dim: int, stage: str = "val"): + """Entry for evaluation.""" + print("Running evaluation...") + assert ( + n_feat_offsets == self.cfg.n_feat_offsets + ), f"Feature offset count changed, should be {n_feat_offsets}" + assert ( + feat_dim == self.cfg.feat_dim + ), f"Feature dim changed, should be {feat_dim}" + + cfg = self.cfg + device = self.device + world_rank = self.world_rank + world_size = self.world_size + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = defaultdict(list) + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + colors, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=None, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + ) # [1, H, W, 3] + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + colors = torch.clamp(colors, 0.0, 1.0) + canvas_list = [pixels, colors] + + if world_rank == 0: + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + imageio.imwrite( + f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", + canvas, + ) + + pixels_p = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors_p = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors_p, pixels_p)) + metrics["ssim"].append(self.ssim(colors_p, pixels_p)) + metrics["lpips"].append(self.lpips(colors_p, pixels_p)) + if cfg.use_bilateral_grid: + cc_colors = color_correct(colors, pixels) + cc_colors_p = cc_colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p)) + + if world_rank == 0: + ellipse_time /= len(valloader) + + stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()} + stats.update( + { + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["gauss_params"]["anchors"]), + } + ) + print( + f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} " + f"Time: {stats['ellipse_time']:.3f}s/image " + f"Number of GS: {stats['num_GS']}" + ) + # save stats as json + with open(f"{self.stats_dir}/{stage}_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"{stage}/{k}", v, step) + self.writer.flush() + + @torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds_all = self.parser.camtoworlds[5:-5] + if cfg.render_traj_path == "interp": + camtoworlds_all = generate_interpolated_path( + camtoworlds_all, 1 + ) # [N, 3, 4] + elif cfg.render_traj_path == "ellipse": + height = camtoworlds_all[:, 2, 3].mean() + camtoworlds_all = generate_ellipse_path_z( + camtoworlds_all, height=height + ) # [N, 3, 4] + elif cfg.render_traj_path == "spiral": + camtoworlds_all = generate_spiral_path( + camtoworlds_all, + bounds=self.parser.bounds * self.scene_scale, + spiral_scale_r=self.parser.extconf["spiral_radius_scale"], + ) + else: + raise ValueError( + f"Render trajectory type not supported: {cfg.render_traj_path}" + ) + + camtoworlds_all = np.concatenate( + [ + camtoworlds_all, + np.repeat( + np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 + ), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + canvas_all = [] + for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): + camtoworlds = camtoworlds_all[i : i + 1] + Ks = K[None] + + renders, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=None, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + depths = renders[..., 3:4] # [1, H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list = [colors, depths.repeat(1, 1, 1, 3)] + + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + canvas_all.append(canvas) + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=24) + for canvas in canvas_all: + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + render_colors, _, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=W, + height=H, + sh_degree=None, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + return render_colors[0].cpu().numpy() + + +def main(local_rank: int, world_rank, world_size: int, cfg: Config): + if world_size > 1 and not cfg.disable_viewer: + cfg.disable_viewer = True + if world_rank == 0: + print("Viewer is disabled in distributed training.") + + runner = Runner(local_rank, world_rank, world_size, cfg) + + if cfg.ckpt is not None: + # run eval only + ckpts = [ + torch.load(file, map_location=runner.device, weights_only=True) + for file in cfg.ckpt + ] + + for k in runner.splats["gauss_params"].keys(): + runner.splats["gauss_params"][k].data = torch.cat( + [ckpt["gauss_params"][k] for ckpt in ckpts] + ) + for k in runner.splats["decoders"].keys(): + runner.splats["decoders"][k].load_state_dict(ckpts[0][k]) + step = ckpts[0]["step"] + n_feat_offsets = ckpts[0]["n_feat_offsets"] + feat_dim = ckpts[0]["feat_dim"] + runner.eval(step=step, n_feat_offsets=n_feat_offsets, feat_dim=feat_dim) + runner.render_traj(step=step) + else: + runner.train() + + if not cfg.disable_viewer: + print("Viewer running... Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + """ + Usage: + + ```bash + # Single GPU training + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default + + # Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps. + CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25 + + """ + + cfg = tyro.cli(Config) + cfg.adjust_steps(cfg.steps_scaler) + + cli(main, cfg, verbose=True) diff --git a/gsplat/__init__.py b/gsplat/__init__.py index df47d1555..03ba23b40 100644 --- a/gsplat/__init__.py +++ b/gsplat/__init__.py @@ -23,7 +23,7 @@ rasterization_inria_wrapper, rasterization_2dgs_inria_wrapper, ) -from .strategy import DefaultStrategy, MCMCStrategy, Strategy +from .strategy import DefaultStrategy, MCMCStrategy, ScaffoldStrategy, Strategy from .version import __version__ all = [ diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 83ec6e77b..b94d736ef 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -582,6 +582,218 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso return render_colors, render_alphas, meta +def view_to_visible_anchors( + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] + scales: Tensor, # [N, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, + far_plane: float = 1e10, + radius_clip: float = 0.0, + eps2d: float = 0.3, + sh_degree: Optional[int] = None, + packed: bool = True, + rasterize_mode: Literal["classic", "antialiased"] = "classic", + camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", + covars: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Dict]: + """Rasterize a set of 3D Gaussians (N) to a batch of image planes (C). + + This function provides a handful features for 3D Gaussian rasterization, which + we detail in the following notes. A complete profiling of the these features + can be found in the :ref:`profiling` page. + + .. note:: + **Multi-GPU Distributed Rasterization**: This function can be used in a multi-GPU + distributed scenario by setting `distributed` to True. When `distributed` is True, + a subset of total Gaussians could be passed into this function in each rank, and + the function will collaboratively render a set of images using Gaussians from all ranks. Note + to achieve balanced computation, it is recommended (not enforced) to have similar number of + Gaussians in each rank. But we do enforce that the number of cameras to be rendered + in each rank is the same. The function will return the rendered images + corresponds to the input cameras in each rank, and allows for gradients to flow back to the + Gaussians living in other ranks. For the details, please refer to the paper + `On Scaling Up 3D Gaussian Splatting Training `_. + + .. note:: + **Batch Rasterization**: This function allows for rasterizing a set of 3D Gaussians + to a batch of images in one go, by simplly providing the batched `viewmats` and `Ks`. + + .. note:: + **Support N-D Features**: If `sh_degree` is None, + the `colors` is expected to be with shape [N, D] or [C, N, D], in which D is the channel of + the features to be rendered. The computation is slow when D > 32 at the moment. + If `sh_degree` is set, the `colors` is expected to be the SH coefficients with + shape [N, K, 3] or [C, N, K, 3], where K is the number of SH bases. In this case, it is expected + that :math:`(\\textit{sh_degree} + 1) ^ 2 \\leq K`, where `sh_degree` controls the + activated bases in the SH coefficients. + + .. note:: + **Depth Rendering**: This function supports colors or/and depths via `render_mode`. + The supported modes are "RGB", "D", "ED", "RGB+D", and "RGB+ED". "RGB" renders the + colored image that respects the `colors` argument. "D" renders the accumulated z-depth + :math:`\\sum_i w_i z_i`. "ED" renders the expected z-depth + :math:`\\frac{\\sum_i w_i z_i}{\\sum_i w_i}`. "RGB+D" and "RGB+ED" render both + the colored image and the depth, in which the depth is the last channel of the output. + + .. note:: + **Memory-Speed Trade-off**: The `packed` argument provides a trade-off between + memory footprint and runtime. If `packed` is True, the intermediate results are + packed into sparse tensors, which is more memory efficient but might be slightly + slower. This is especially helpful when the scene is large and each camera sees only + a small portion of the scene. If `packed` is False, the intermediate results are + with shape [C, N, ...], which is faster but might consume more memory. + + .. note:: + **Sparse Gradients**: If `sparse_grad` is True, the gradients for {means, quats, scales} + will be stored in a `COO sparse layout `_. + This can be helpful for saving memory + for training when the scene is large and each iteration only activates a small portion + of the Gaussians. Usually a sparse optimizer is required to work with sparse gradients, + such as `torch.optim.SparseAdam `_. + This argument is only effective when `packed` is True. + + .. note:: + **Speed-up for Large Scenes**: The `radius_clip` argument is extremely helpful for + speeding up large scale scenes or scenes with large depth of fields. Gaussians with + 2D radius smaller or equal than this value (in pixel unit) will be skipped during rasterization. + This will skip all the far-away Gaussians that are too small to be seen in the image. + But be warned that if there are close-up Gaussians that are also below this threshold, they will + also get skipped (which is rarely happened in practice). This is by default disabled by setting + `radius_clip` to 0.0. + + .. note:: + **Antialiased Rendering**: If `rasterize_mode` is "antialiased", the function will + apply a view-dependent compensation factor + :math:`\\rho=\\sqrt{\\frac{Det(\\Sigma)}{Det(\\Sigma+ \\epsilon I)}}` to Gaussian + opacities, where :math:`\\Sigma` is the projected 2D covariance matrix and :math:`\\epsilon` + is the `eps2d`. This will make the rendered image more antialiased, as proposed in + the paper `Mip-Splatting: Alias-free 3D Gaussian Splatting `_. + + .. note:: + **AbsGrad**: If `absgrad` is True, the absolute gradients of the projected + 2D means will be computed during the backward pass, which could be accessed by + `meta["means2d"].absgrad`. This is an implementation of the paper + `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_, + which is shown to be more effective for splitting Gaussians during training. + + .. warning:: + This function is currently not differentiable w.r.t. the camera intrinsics `Ks`. + + Args: + means: The 3D centers of the Gaussians. [N, 3] + quats: The quaternions of the Gaussians (wxyz convension). It's not required to be normalized. [N, 4] + scales: The scales of the Gaussians. [N, 3] + viewmats: The world-to-cam transformation of the cameras. [C, 4, 4] + Ks: The camera intrinsics. [C, 3, 3] + width: The width of the image. + height: The height of the image. + near_plane: The near plane for clipping. Default is 0.01. + far_plane: The far plane for clipping. Default is 1e10. + radius_clip: Gaussians with 2D radius smaller or equal than this value will be + skipped. This is extremely helpful for speeding up large scale scenes. + Default is 0.0. + eps2d: An epsilon added to the egienvalues of projected 2D covariance matrices. + This will prevents the projected GS to be too small. For example eps2d=0.3 + leads to minimal 3 pixel unit. Default is 0.3. + sh_degree: The SH degree to use, which can be smaller than the total + number of bands. If set, the `colors` should be [(C,) N, K, 3] SH coefficients, + else the `colors` should [(C,) N, D] post-activation color values. Default is None. + packed: Whether to use packed mode which is more memory efficient but might or + might not be as fast. Default is True. + tile_size: The size of the tiles for rasterization. Default is 16. + (Note: other values are not tested) + backgrounds: The background colors. [C, D]. Default is None. + render_mode: The rendering mode. Supported modes are "RGB", "D", "ED", "RGB+D", + and "RGB+ED". "RGB" renders the colored image, "D" renders the accumulated depth, and + "ED" renders the expected depth. Default is "RGB". + sparse_grad: If true, the gradients for {means, quats, scales} will be stored in + a COO sparse layout. This can be helpful for saving memory. Default is False. + absgrad: If true, the absolute gradients of the projected 2D means + will be computed during the backward pass, which could be accessed by + `meta["means2d"].absgrad`. Default is False. + rasterize_mode: The rasterization mode. Supported modes are "classic" and + "antialiased". Default is "classic". + channel_chunk: The number of channels to render in one go. Default is 32. + If the required rendering channels are larger than this value, the rendering + will be done looply in chunks. + distributed: Whether to use distributed rendering. Default is False. If True, + The input Gaussians are expected to be a subset of scene in each rank, and + the function will collaboratively render the images for all ranks. + ortho: Whether to use orthographic projection. In such case fx and fy become the scaling + factors to convert projected coordinates into pixel space and cx, cy become offsets. + covars: Optional covariance matrices of the Gaussians. If provided, the `quats` and + `scales` will be ignored. [N, 3, 3], Default is None. + + Returns: + A tuple: + + **render_colors**: The rendered colors. [C, height, width, X]. + X depends on the `render_mode` and input `colors`. If `render_mode` is "RGB", + X is D; if `render_mode` is "D" or "ED", X is 1; if `render_mode` is "RGB+D" or + "RGB+ED", X is D+1. + + **render_alphas**: The rendered alphas. [C, height, width, 1]. + + **meta**: A dictionary of intermediate results of the rasterization. + + """ + N = means.shape[0] + C = viewmats.shape[0] + assert means.shape == (N, 3), means.shape + if covars is None: + assert quats.shape == (N, 4), quats.shape + assert scales.shape == (N, 3), scales.shape + else: + assert covars.shape == (N, 3, 3), covars.shape + quats, scales = None, None + # convert covars from 3x3 matrix to upper-triangular 6D vector + tri_indices = ([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]) + covars = covars[..., tri_indices[0], tri_indices[1]] + assert viewmats.shape == (C, 4, 4), viewmats.shape + assert Ks.shape == (C, 3, 3), Ks.shape + + # Project Gaussians to 2D. Directly pass in {quats, scales} is faster than precomputing covars. + proj_results = fully_fused_projection( + means, + covars, + quats, + scales, + viewmats, + Ks, + width, + height, + eps2d=eps2d, + packed=packed, + near_plane=near_plane, + far_plane=far_plane, + radius_clip=radius_clip, + sparse_grad=False, + calc_compensations=(rasterize_mode == "antialiased"), + camera_model=camera_model, + ) + + if packed: + # The results are packed into shape [nnz, ...]. All elements are valid. + ( + _, + _, + radii, + _, + _, + _, + _, + ) = proj_results + else: + # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. + radii, _, _, _, _ = proj_results + + return radii.squeeze(0) > 0 + + def _rasterization( means: Tensor, # [N, 3] quats: Tensor, # [N, 4] @@ -1447,7 +1659,7 @@ def rasterization_2dgs_inria_wrapper( render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median ) - normals_surf = depth_to_normal(render_depth, torch.linalg.inv(viewmats), Ks) + normals_surf = depth_to_normal(render_depth, viewmats, Ks) normals_surf = normals_surf * (render_alphas).detach() render_colors = torch.cat([render_colors, render_depth], dim=-1) diff --git a/gsplat/strategy/__init__.py b/gsplat/strategy/__init__.py index 305dc8129..08ac72b8f 100644 --- a/gsplat/strategy/__init__.py +++ b/gsplat/strategy/__init__.py @@ -1,3 +1,4 @@ from .base import Strategy from .default import DefaultStrategy from .mcmc import MCMCStrategy +from .scaffold import ScaffoldStrategy diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 0789dfcbc..31affe6c2 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -180,13 +180,16 @@ def remove( optimizers: Dict[str, torch.optim.Optimizer], state: Dict[str, Tensor], mask: Tensor, + names: Union[List[str], None] = None, ): """Inplace remove the Gaussian with the given mask. Args: params: A dictionary of parameters. optimizers: A dictionary of optimizers, each corresponding to a parameter. + state: A dictionary of extra state tensors. mask: A boolean mask to remove the Gaussians. + names: A list of key names to update. If None, update all. Default: None. """ sel = torch.where(~mask)[0] @@ -197,7 +200,7 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: return v[sel] # update the parameters and the state in the optimizers - _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) # update the extra running state for k, v in state.items(): if isinstance(v, torch.Tensor): @@ -265,7 +268,7 @@ def relocate( sampled_idxs = alive_indices[sampled_idxs] new_opacities, new_scales = compute_relocation( opacities=opacities[sampled_idxs], - scales=torch.exp(params["scales"])[sampled_idxs], + scales=torch.exp(params["scales"][:, :3])[sampled_idxs], ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, binoms=binoms, ) @@ -275,20 +278,26 @@ def param_fn(name: str, p: Tensor) -> Tensor: if name == "opacities": p[sampled_idxs] = torch.logit(new_opacities) elif name == "scales": - p[sampled_idxs] = torch.log(new_scales) + p[sampled_idxs][:, :3] = torch.log(new_scales) p[dead_indices] = p[sampled_idxs] return torch.nn.Parameter(p) def optimizer_fn(key: str, v: Tensor) -> Tensor: v[sampled_idxs] = 0 + v[dead_indices] = 0 return v # update the parameters and the state in the optimizers _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) # update the extra running state for k, v in state.items(): - if isinstance(v, torch.Tensor): - v[sampled_idxs] = 0 + if isinstance(v, torch.Tensor) and k != "binoms": + if k == "anchor_count" or k == "anchor_opacity": + v[sampled_idxs] = 0 + v[dead_indices] = 0 + else: + v.view(-1, state["n_feat_offsets"])[sampled_idxs] = 0 + v.view(-1, state["n_feat_offsets"])[dead_indices] = 0 @torch.no_grad() @@ -359,5 +368,147 @@ def op_sigmoid(x, k=100, x0=0.995): * (op_sigmoid(1 - opacities)).unsqueeze(-1) * scaler ) - noise = torch.einsum("bij,bj->bi", covars, noise) params["means"].add_(noise) + + +@torch.no_grad() +def grow_anchors( + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Tensor], + anchors: torch.Tensor, + gradient_mask: torch.Tensor, + remove_duplicates_mask: torch.Tensor, + inv_idx: torch.Tensor, + voxel_size: float, + n_feat_offsets: int, + feat_dim: int, +): + """Inplace add new Gaussians (anchors) to the parameters. + + Args: + params: A dictionary of parameters. + optimizers: A dictionary of optimizers, each corresponding to a parameter. + state: A dictionary of extra state tensors. + anchors: Positions of new anchors to be added. + gradient_mask: A mask to select gradients. + remove_duplicates_mask: A mask to remove duplicates. + inv_idx: Indices for inverse mapping. + voxel_size: The size of the voxel. + n_feat_offsets: Number of feature offsets. + feat_dim: Dimension of features. + """ + device = anchors.device + num_new = anchors.size(0) + + # Scale anchors + anchors = anchors * voxel_size # [N_new, 3] + + # Initialize new parameters + log_voxel_size = torch.log(torch.tensor(voxel_size, device=device)) + scaling = log_voxel_size.expand(num_new, anchors.size(1) * 2) # [N_new, 6] + + rotation = torch.ones((num_new, 4), device=device) + + # Prepare new features + existing_features = params["features"] # [N_existing, feat_dim] + repeated_features = ( + existing_features.unsqueeze(1) + .expand(-1, n_feat_offsets, -1) + .reshape(-1, existing_features.shape[1]) + ) # [N_existing * n_feat_offsets, feat_dim] + + selected_features = repeated_features[gradient_mask] # [N_selected, feat_dim] + + # Use inverse_indices to aggregate features + scattered_features = torch.segment_reduce( + data=selected_features, reduce="amax", lengths=torch.bincount(inv_idx) + ) + feat = scattered_features[remove_duplicates_mask] # [N_new, feat_dim] + + def inverse_sigmoid(x): + return torch.log(x / (1 - x)) + + opacities = inverse_sigmoid( + 0.1 * torch.ones((anchors.shape[0], 1), dtype=torch.float, device="cuda") + ) + # Initialize new offsets + offsets = torch.zeros( + (num_new, n_feat_offsets, 3), device=device + ) # [N_new, n_feat_offsets, 3] + + def param_fn(name: str, p: Tensor) -> Tensor: + if name == "anchors": + p_new = torch.cat([p, anchors], dim=0) + elif name == "scales": + p_new = torch.cat([p, scaling], dim=0) + elif name == "quats": + p_new = torch.cat([p, rotation], dim=0) + elif name == "features": + p_new = torch.cat([p, feat], dim=0) + elif name == "offsets": + p_new = torch.cat([p, offsets], dim=0) + elif name == "opacities": + p_new = torch.cat([p, opacities], dim=0) + else: + raise ValueError(f"Parameter '{name}' not recognized.") + return torch.nn.Parameter(p_new) + + def optimizer_fn(key: str, v: Tensor) -> Tensor: + # Extend optimizer state tensors with zeros + zeros = torch.zeros((num_new, *v.shape[1:]), device=device) + v_new = torch.cat([v, zeros], dim=0) + return v_new + + # Update parameters and optimizer states + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) + + # Update the extra running state + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != "binoms": + if k == "anchor_count" or k == "anchor_opacity": + zeros = torch.zeros((num_new, *v.shape[1:]), device=device) + else: + zeros = torch.zeros( + (num_new * n_feat_offsets, *v.shape[1:]), device=device + ) + state[k] = torch.cat([v, zeros], dim=0) + + +@torch.no_grad() +def remove_anchors( + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + n_feat_offsets: int, + state: Dict[str, Tensor], + mask: Tensor, + names: Union[List[str], None] = None, +): + """Inplace remove the Gaussian with the given mask. + + Args: + params: A dictionary of parameters. + optimizers: A dictionary of optimizers, each corresponding to a parameter. + n_feat_offsets: Number of feature offsets. + state: A dictionary of extra state tensors. + mask: A boolean mask to remove the Gaussians. + names: A list of parameter names to update. If None, update all. Default: None. + """ + sel = torch.where(~mask)[0] + + def param_fn(name: str, p: Tensor) -> Tensor: + return torch.nn.Parameter(p[sel]) + + def optimizer_fn(key: str, v: Tensor) -> Tensor: + return v[sel] + + # update the parameters and the state in the optimizers + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) + # update the extra running state + for k, v in state.items(): + if isinstance(v, torch.Tensor) and k != "binoms": + if k in ["anchor_count", "anchor_opacity"]: + state[k] = v[sel] + else: + offset_sel = sel.unsqueeze(dim=1).repeat([1, n_feat_offsets]).view(-1) + state[k] = v[offset_sel] diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py new file mode 100644 index 000000000..d331172c9 --- /dev/null +++ b/gsplat/strategy/scaffold.py @@ -0,0 +1,491 @@ +from dataclasses import dataclass +from typing import Any, Dict, Union +import torch + +from .base import Strategy +from .ops import remove_anchors, grow_anchors +from .ops import relocate +import math + + +@dataclass +class ScaffoldStrategy(Strategy): + """A neural gaussian strategy that follows the paper: + + `Scaffold-GS: Structured 3D Gaussians for View-Adaptive Rendering `_ + + The strategy will: + + - Periodically grows anchors with high image plane gradients. + - Periodically teleport anchors with low opacity to a place that has high opacity. + + If `absgrad=True`, it will use the absolute gradients instead of average gradients + for GS duplicating & splitting, following the AbsGS paper: + + `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_ + + Which typically leads to better results but requires to set the `grow_grad2d` to a + higher value, e.g., 0.0008. Also, the :func:`rasterization` function should be called + with `absgrad=True` as well so that the absolute gradients are computed. + + Args: + prune_opa (float): Threshold for pruning GSs with opacity below this value. Default is 0.005. + grow_grad2d (float): Threshold for splitting/duplicating GSs based on image plane gradient. Default is 0.0002. + refine_start_iter (int): Iteration to start refining GSs. Default is 500. + refine_stop_iter (int): Iteration to stop refining GSs. Default is 15,000. + refine_every (int): Frequency (in steps) at which GSs are refined. Default is 100. + absgrad (bool): Whether to use absolute gradients for GS splitting. Default is False. + verbose (bool): If True, prints detailed information during refinement. Default is False. + max_voxel_levels (int): Maximum levels for voxel splitting during GS growth. Default is 3. + voxel_size (float): Base size of the voxel used in GS growth. Default is 0.001. + pruning_thresholds (float): Threshold for pruning based on refinement steps. Default is 0.8. + growing_thresholds (float): Threshold for GS growth based on refinement steps. Default is 0.4. + drop_out (bool): If True, applies dropout during GS growth to prevent overgrowth. Default is False. + pruning (bool): If True, enables pruning of GSs during refinement. Default is False. + + Examples: + + >>> from gsplat import ScaffoldStrategy, rasterization + >>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ... + >>> optimizers: Dict[str, torch.optim.Optimizer] = ... + >>> strategy = ScaffoldStrategy() + >>> strategy.check_sanity(params, optimizers) + >>> strategy_state = strategy.initialize_state() + >>> for step in range(1000): + ... render_image, render_alpha, info = rasterization(...) + ... strategy.step_pre_backward(params, optimizers, strategy_state, step, info) + ... loss = ... + ... loss.backward() + ... strategy.step_post_backward(params, optimizers, strategy_state, step, info) + + """ + + prune_opa: float = 0.005 + grow_grad2d: float = 1.28e-4 + refine_start_iter: int = 800 + absgrad: bool = False + max_voxel_levels: int = 3 + voxel_size: float = 0.001 + refine_stop_iter: int = 15_000 + refine_every: int = 100 + + # 3.3 Observation Thresholds (compare paper) + # "To enhance the robustness of the Growing and Pruning operations for long image sequences, ..." + pruning_thresholds: float = 0.8 + growing_thresholds: float = 0.4 + + verbose: bool = True + drop_out: bool = False + pruning: bool = False + + def initialize_state( + self, scene_scale: float = 1.0, feat_dim=128, n_feat_offsets=10 + ) -> Dict[str, Any]: + """Initialize and return the running state for this strategy. + + The returned state should be passed to the `step_pre_backward()` and + `step_post_backward()` functions. + """ + # Postpone the initialization of the state to the first step so that we can + # put them on the correct device. + # - grad2d: running accum of the norm of the image plane gradients for each GS. + # - count: running accum of how many time each GS is visible. + + n_max = 51 + binoms = torch.zeros((n_max, n_max)) + for n in range(n_max): + for k in range(n + 1): + binoms[n, k] = math.comb(n, k) + state = { + "binoms": binoms, + "grad2d": None, + "count": None, + "anchor_count": None, + "anchor_opacity": None, + "scene_scale": scene_scale, + "feat_dim": feat_dim, + "n_feat_offsets": n_feat_offsets, + } + return state + + def check_sanity( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + ): + """Sanity check for the parameters and optimizers. + + Check if: + * `params` and `optimizers` have the same keys. + * Each optimizer has exactly one param_group, corresponding to each parameter. + * The following keys are present: {"anchors", "scales", "quats", "opacities"}. + + Raises: + AssertionError: If any of the above conditions is not met. + + .. note:: + It is not required but highly recommended for the user to call this function + after initializing the strategy to ensure the convention of the parameters + and optimizers is as expected. + """ + + super().check_sanity(params, optimizers) + # The following keys are required for this strategy. + expected_params = [ + "anchors", + "features", + "offsets", + "scales", + "quats", + "opacities", + ] + + assert len(expected_params) == len( + params + ), "expected params and actual params don't match" + for key in expected_params: + assert key in params, f"{key} is required in params but missing." + + def step_pre_backward( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + step: int, + info: Dict[str, Any], + ): + """Callback function to be executed before the `loss.backward()` call.""" + assert ( + "means2d" in info + ), "The 2D anchors of the Gaussians is required but missing." + info["means2d"].retain_grad() + + def step_post_backward( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + step: int, + info: Dict[str, Any], + packed: bool = False, + ): + """Callback function to be executed after the `loss.backward()` call.""" + + if step >= self.refine_stop_iter: + return + + # move to the correct device + state["binoms"] = state["binoms"].to(params["anchors"].device) + binoms = state["binoms"] + + if step > 500: + self._update_state(params, state, info, packed=packed) + + if step > self.refine_start_iter and step % self.refine_every == 0: + # grow GSs + new_anchors = self._anchor_growing(params, optimizers, state) + if self.verbose: + print( + f"Step {step}: {new_anchors} anchors grown. Now having {len(params['anchors'])} anchors." + ) + + if self.pruning: + low_opacity_mask = ( + state["anchor_opacity"] < self.prune_opa * state["anchor_count"] + ).squeeze() + anchor_mask = ( + state["anchor_count"] > self.pruning_thresholds * self.refine_every + ) # [N, 1] + is_prune = torch.logical_and(low_opacity_mask, anchor_mask) + + indices = is_prune.nonzero(as_tuple=False).squeeze() + # Efficiently set the specified indices to zero + state["anchor_count"].index_fill_(0, indices, 0) + state["anchor_opacity"].index_fill_(0, indices, 0) + + n_prune = is_prune.sum().item() + if n_prune > 0: + remove_anchors( + params=params, + optimizers=optimizers, + n_feat_offsets=state["n_feat_offsets"], + state=state, + mask=is_prune, + ) + + if self.verbose: + print( + f"Step {step}: {n_prune} anchors pruned. Now having {len(params['anchors'])} anchors." + ) + else: + n_relocated_gs = self._relocate_gs(params, optimizers, binoms, state) + if self.verbose: + print(f"Relocated anchors {n_relocated_gs}") + + torch.cuda.empty_cache() + + def _update_state( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + state: Dict[str, Any], + info: Dict[str, Any], + packed: bool = False, + ): + # Ensure required keys are present + required_keys = ["width", "height", "n_cameras", "radii", "gaussian_ids"] + for key in required_keys: + assert key in info, f"{key} is required but missing." + + # Normalize gradients to [-1, 1] screen space + factor = 0.5 * info["n_cameras"] + scale_factors = torch.tensor( + [info["width"] * factor, info["height"] * factor], + device=info["means2d"].device, + ) + + if self.absgrad: + grads = info["means2d"].absgrad.detach() * scale_factors + else: + grads = info["means2d"].grad.detach() * scale_factors + + # Initialize state on the first run + n_gaussian = params["anchors"].shape[0] + device = grads.device + if state["grad2d"] is None: + state["grad2d"] = torch.zeros( + n_gaussian * state["n_feat_offsets"], device=device + ) + if state["count"] is None: + state["count"] = torch.zeros( + n_gaussian * state["n_feat_offsets"], device=device + ) + if state["anchor_count"] is None: + state["anchor_count"] = torch.zeros(n_gaussian, device=device) + if state["anchor_opacity"] is None: + state["anchor_opacity"] = torch.zeros(n_gaussian, device=device) + + neural_selection_mask = info["neural_selection_mask"] + + # Update the running state + sel = info["radii"] > 0.0 # [C, N] + neural_ids = neural_selection_mask.nonzero(as_tuple=False).squeeze(-1) + grads = grads[sel].norm(dim=-1) # [nnz, 2] + sel = sel.squeeze(0) + + valid_ids = neural_ids[sel] + + # Update state using index_add_ + state["grad2d"].index_add_(0, valid_ids, grads) + state["count"].index_add_( + 0, + valid_ids, + torch.ones((valid_ids.shape[0]), dtype=torch.float32, device=device), + ) + + # Update anchor opacity and count + visible_anchor_mask = info["visible_anchor_mask"] + anchor_ids = visible_anchor_mask.nonzero(as_tuple=False).squeeze(-1) + neural_opacities = ( + info["neural_opacities"] + .detach() + .view(-1, state["n_feat_offsets"]) + .clamp_min_(0) + .sum(dim=1) + ) + + state["anchor_opacity"].index_add_(0, anchor_ids, neural_opacities) + state["anchor_count"].index_add_( + 0, anchor_ids, torch.ones_like(anchor_ids, dtype=torch.float32) + ) + + @torch.no_grad() + def _anchor_growing( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + ) -> int: + """ + Implements the Anchor Growing algorithm as described in Algorithm 1 of the + GS-Scaffold appendix: + https://openaccess.thecvf.com/content/CVPR2024/supplemental/Lu_Scaffold-GS_Structured_3D_CVPR_2024_supplemental.pdf + + This method performs anchor growing for structured optimization during training, + which helps improve the generalization and stability of the model. + + Args: + params (Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict]): + The model's parameters to be optimized. + optimizers (Dict[str, torch.optim.Optimizer]): + A dictionary of optimizers associated with the model's parameters. + state (Dict[str, Any]): + A dictionary containing the current state of the training process. + + Returns: + int: + Number of growned anchors. + """ + + count = state["count"] + grads = state["grad2d"] / count.clamp_min(1) + grads[grads.isnan()] = 0.0 + device = grads.device + + n_init_features = state["count"].shape[0] + n_added_anchors = 0 + + # Algorithm 1: Anchor Growing + # Step 1: Initialization + m = 1 # Iteration count (levels) + M = self.max_voxel_levels + tau_g = self.grow_grad2d + epsilon_g = self.voxel_size + new_anchors: torch.tensor + + growing_threshold_mask = ( + state["count"] > self.refine_every * self.growing_thresholds + ) + # Step 2: Iterate while m <= M + while m <= M: + n_feature_diff = state["count"].shape[0] - n_init_features + # Check if anchor candidates have grown + if n_feature_diff == 0 and m > 1: + break + + # Step 3: Update threshold and voxel size + tau = tau_g * (2 ** (m - 1)) + current_voxel_size = (16 // (4 ** (m - 1))) * epsilon_g + + # Step 4: Mask from grad threshold. Select neural gaussians (Select candidates) + gradient_mask = grads >= tau + gradient_mask = torch.logical_and(gradient_mask, growing_threshold_mask) + + # Drop-out: Helps prevent too massive anchor growth. + if self.drop_out: + rand_mask = torch.rand_like(gradient_mask.float()) > (0.5**m) + rand_mask = rand_mask.cuda() + gradient_mask = torch.logical_and(gradient_mask, rand_mask) + + gradient_mask = torch.cat( + [ + gradient_mask, + torch.zeros(n_feature_diff, dtype=torch.bool, device=device), + ], + dim=0, + ) + + # Compute neural gaussians + neural_gaussians = params["anchors"].unsqueeze(dim=1) + params[ + "offsets" + ] * torch.exp(params["scales"][:, :3]).unsqueeze(dim=1) + selected_neural_gaussians = neural_gaussians.view([-1, 3])[gradient_mask] + + # Step 5: Merge same positions + selected_grid_coords = torch.round( + selected_neural_gaussians / current_voxel_size + ).int() + selected_grid_coords_unique, inv_idx = torch.unique( + selected_grid_coords, return_inverse=True, dim=0 + ) + + # Get the grid coordinates of the current anchors + grid_anchor_coords = torch.round( + params["anchors"] / current_voxel_size + ).int() + + # Step 6: Remove occupied by comparing the unique coordinates to current anchors + def coords_to_indices(coords, N): + """ + Maps quantized multi-dimensional coordinates to unique 1D indices without using torch.matmul. + + Args: + coords (torch.Tensor): Tensor of shape [num_points, D], where D is the dimensionality. + N (int): A large enough number to ensure uniqueness of indices. + + Returns: + torch.Tensor: Tensor of unique indices corresponding to the coordinates. + """ + D = coords.shape[1] + device = coords.device + dtype = coords.dtype # Keep the original data type + + # Compute N_powers as integers + N_powers = N ** torch.arange(D - 1, -1, -1, device=device, dtype=dtype) + + # Perform element-wise multiplication and sum along the last dimension + indices = (coords * N_powers).sum(dim=1) + + return indices + + # Quantize the coordinates + decimal_places = 6 # precision + scale = 10**decimal_places + + # Quantize the coordinates by scaling and converting to integers + selected_coords_quant = (selected_grid_coords_unique * scale).long() + anchor_coords_quant = (grid_anchor_coords * scale).long() + + # Compute the maximum coordinate value + max_coord_value = ( + torch.max(torch.cat([selected_coords_quant, anchor_coords_quant])) + 1 + ) + N = max_coord_value.item() + + # Compute unique indices for both coordinate sets + indices_selected = coords_to_indices(selected_coords_quant, N) + indices_anchor = coords_to_indices(anchor_coords_quant, N) + + remove_occupied_pos_mask = ~torch.isin(indices_selected, indices_anchor) + + # New anchor candidates are those unique coordinates that are not duplicates + new_anchors = selected_grid_coords_unique[remove_occupied_pos_mask] + + if new_anchors.shape[0] > 0: + grow_anchors( + params=params, + optimizers=optimizers, + state=state, + anchors=new_anchors, + gradient_mask=gradient_mask, + remove_duplicates_mask=remove_occupied_pos_mask, + inv_idx=inv_idx, + voxel_size=current_voxel_size, + n_feat_offsets=state["n_feat_offsets"], + feat_dim=state["feat_dim"], + ) + + n_added_anchors += new_anchors.shape[0] + m += 1 + + indices = torch.arange(n_init_features, device=growing_threshold_mask.device)[ + growing_threshold_mask + ] + + if indices.numel() > 0: + state["count"].index_fill_(0, indices, 0) + state["grad2d"].index_fill_(0, indices, 0) + + return n_added_anchors + + @torch.no_grad() + def _relocate_gs( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + binoms: torch.Tensor, + state: Dict[str, Any], + ) -> int: + dead_mask = ( + state["anchor_opacity"] < self.prune_opa * state["anchor_count"] + ).squeeze() + n_gs = dead_mask.sum().item() + if n_gs > 0: + relocate( + params=params, + optimizers=optimizers, + state=state, + mask=dead_mask, + binoms=binoms, + min_opacity=self.prune_opa, + ) + return n_gs