diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py new file mode 100644 index 000000000..300842750 --- /dev/null +++ b/examples/simple_trainer_scaffold.py @@ -0,0 +1,1069 @@ +import json +import math +import os +import time +from dataclasses import dataclass, field +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import imageio +import nerfview +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import viser +import yaml +from datasets.colmap import Dataset, Parser +from datasets.traj import ( + generate_interpolated_path, + generate_ellipse_path_z, + generate_spiral_path, +) +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from fused_ssim import fused_ssim +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from typing_extensions import Literal +from utils import AppearanceOptModule, CameraOptModule, knn, set_random_seed +from lib_bilagrid import ( + BilateralGrid, + slice, + color_correct, + total_variation_loss, +) + +from gsplat.compression import PngCompression +from gsplat.distributed import cli +from gsplat.rendering import rasterization, filter_visible_gaussians +from gsplat.strategy import ScaffoldStrategy + + +@dataclass +class Config: + # Disable viewer + disable_viewer: bool = False + # Path to the .pt files. If provide, it will skip training and run evaluation only. + ckpt: Optional[List[str]] = None + # Name of compression strategy to use + compression: Optional[Literal["png"]] = None + # Render trajectory path + render_traj_path: str = "interp" + + # Path to the Mip-NeRF 360 dataset + data_dir: str = "/home/paja/data/bike" + # Downsample factor for the dataset + data_factor: int = 1 + # Directory to save results + result_dir: str = "results" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + # Normalize the world space + normalize_world_space: bool = True + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 30_000 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + + # voxel size for Scaffold-GS + voxel_size = 0.001 + # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm + init_extent: float = 3.0 + # Turn on another SH degree every this steps + sh_degree_interval: int = 1000 + # Initial opacity of GS + init_opa: float = 0.1 + # Initial scale of GS + init_scale: float = 1.0 + # Weight for SSIM loss + ssim_lambda: float = 0.2 + + # Near plane clipping distance + near_plane: float = 0.01 + # Far plane clipping distance + far_plane: float = 1e10 + + # Strategy for GS densification + strategy: ScaffoldStrategy = field(default_factory=ScaffoldStrategy) + # Use packed mode for rasterization, this leads to less memory usage but slightly slower. 
+ packed: bool = False + # Use sparse gradients for optimization. (experimental) + sparse_grad: bool = False + # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. + antialiased: bool = False + + # Use random background for training to discourage transparency + random_bkgd: bool = False + + # Opacity regularization + opacity_reg: float = 0.0 + # Scale regularization + scale_reg: float = 0.01 + + # Enable camera optimization. + pose_opt: bool = False + # Learning rate for camera optimization + pose_opt_lr: float = 1e-5 + # Regularization for camera optimization as weight decay + pose_opt_reg: float = 1e-6 + # Add noise to camera extrinsics. This is only to test the camera pose optimization. + pose_noise: float = 0.0 + + # Enable appearance optimization. (experimental) + app_opt: bool = False + # Appearance embedding dimension + app_embed_dim: int = 16 + # Learning rate for appearance optimization + app_opt_lr: float = 1e-3 + # Regularization for appearance optimization as weight decay + app_opt_reg: float = 1e-6 + + # Enable bilateral grid. (experimental) + use_bilateral_grid: bool = False + # Shape of the bilateral grid (X, Y, W) + bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8) + + # Enable depth loss. (experimental) + depth_loss: bool = False + # Weight for depth loss + depth_lambda: float = 1e-2 + + # Dump information to tensorboard every this steps + tb_every: int = 100 + # Save training images to tensorboard + tb_save_image: bool = False + + lpips_net: Literal["vgg", "alex"] = "alex" + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.max_steps = int(self.max_steps * factor) + self.sh_degree_interval = int(self.sh_degree_interval * factor) + + strategy = self.strategy + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + # strategy.reset_every = int(strategy.reset_every * factor) + # strategy.refine_every = int(strategy.refine_every * factor) + + +def create_splats_with_optimizers( + parser: Parser, + strategy: ScaffoldStrategy, + init_extent: float = 3.0, + init_opacity: float = 0.1, + init_scale: float = 1.0, + scene_scale: float = 1.0, + sh_degree: int = 3, + sparse_grad: bool = False, + batch_size: int = 1, + feature_dim: Optional[int] = None, + device: str = "cuda", + world_rank: int = 0, + world_size: int = 1, + voxel_size: float = 0.001, +) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]: + + # Compare GS-Scaffold paper formula (4) + points = np.unique(np.round(parser.points/voxel_size), axis=0)*voxel_size + points = torch.from_numpy(points).float() + + # Initialize the GS size to be the average dist of the 3 nearest neighbors + dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 6) # [N, 3] + + # Distribute the GSs to different ranks (also works for single rank) + points = points[world_rank::world_size] + scales = scales[world_rank::world_size] + + N = points.shape[0] + quats = torch.rand((N, 4)) # [N, 4] + + features = torch.zeros((N, strategy.mean_feat_dim)) + offsets = torch.zeros((N, strategy.n_feat_offsets, 3)) + + params = [ + # name, value, lr + ("anchors", torch.nn.Parameter(points), 1.6e-4 * scene_scale), + ("scales", torch.nn.Parameter(scales), 5e-3), + ("quats", torch.nn.Parameter(quats), 1e-3), 
+ ("opacities_mlp", strategy.opacities_mlp.parameters(), 0.02), + ("features", torch.nn.Parameter(features), 1.6e-4 * scene_scale), + ("offsets", torch.nn.Parameter(offsets), 0.004), + ("colors_mlp", strategy.colors_mlp.parameters(), 0.008), + ("scale_rot_mlp", strategy.scale_rot_mlp.parameters(), 0.004), + ] + + splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) + # Scale learning rate based on batch size, reference: + # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + # Note that this would not make the training exactly equivalent, see + # https://arxiv.org/pdf/2402.18824v1 + BS = batch_size * world_size + optimizers = { + name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}], + eps=1e-15 / math.sqrt(BS), + # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. + betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + ) + for name, _, lr in params + } + return splats, optimizers + + +class Runner: + """Engine for training and testing.""" + + def __init__( + self, local_rank: int, world_rank, world_size: int, cfg: Config + ) -> None: + set_random_seed(42 + local_rank) + + self.cfg = cfg + self.world_rank = world_rank + self.local_rank = local_rank + self.world_size = world_size + self.device = f"cuda:{local_rank}" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}/ckpts" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # Load data: Training data should contain initial points and colors. + self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=cfg.normalize_world_space, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + print("Scene scale:", self.scene_scale) + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers = create_splats_with_optimizers( + self.parser, + strategy=self.cfg.strategy, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + scene_scale=self.scene_scale, + sparse_grad=cfg.sparse_grad, + batch_size=cfg.batch_size, + feature_dim=feature_dim, + device=self.device, + world_rank=world_rank, + world_size=world_size, + voxel_size=cfg.voxel_size, + ) + print("Model initialized. 
Number of GS:", len(self.splats["anchors"])) + + # Densification Strategy + self.cfg.strategy.check_sanity(self.splats, self.optimizers) + + self.strategy_state = self.cfg.strategy.initialize_state( + scene_scale=self.scene_scale + ) + # Compression Strategy + self.compression_method = None + if cfg.compression is not None: + if cfg.compression == "png": + self.compression_method = PngCompression() + else: + raise ValueError(f"Unknown compression strategy: {cfg.compression}") + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + if world_size > 1: + self.pose_adjust = DDP(self.pose_adjust) + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + if world_size > 1: + self.pose_perturb = DDP(self.pose_perturb) + + self.app_optimizers = [] + if cfg.app_opt: + assert feature_dim is not None + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. + torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + if world_size > 1: + self.app_module = DDP(self.app_module) + + self.bil_grid_optimizers = [] + if cfg.use_bilateral_grid: + self.bil_grids = BilateralGrid( + len(self.trainset), + grid_X=cfg.bilateral_grid_shape[0], + grid_Y=cfg.bilateral_grid_shape[1], + grid_W=cfg.bilateral_grid_shape[2], + ).to(self.device) + self.bil_grid_optimizers = [ + torch.optim.Adam( + self.bil_grids.parameters(), + lr=2e-3 * math.sqrt(cfg.batch_size), + eps=1e-15, + ), + ] + + # Losses & Metrics. 
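As an aside on the square-root batch-size scaling used in `create_splats_with_optimizers` above: the learning rate, Adam's `eps`, and its `betas` are all rescaled by the effective batch size `BS = batch_size * world_size`, and, as the TODO there notes, the linear adjustment of `betas` degenerates once `BS` reaches 10. A standalone sketch of that arithmetic (illustrative helper, not part of this diff):

```python
import math

def sqrt_scaled_adam_hparams(base_lr: float, batch_size: int, world_size: int = 1):
    # Hypothetical helper mirroring the scaling rule used above.
    bs = batch_size * world_size
    return {
        "lr": base_lr * math.sqrt(bs),
        "eps": 1e-15 / math.sqrt(bs),
        # beta1 is 0.9 at bs=1, ~0.0 at bs=10, and negative (rejected by torch's Adam) beyond.
        "betas": (1 - bs * (1 - 0.9), 1 - bs * (1 - 0.999)),
    }

print(sqrt_scaled_adam_hparams(1.6e-4, batch_size=10))
```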
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + + if cfg.lpips_net == "alex": + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="alex", normalize=True + ).to(self.device) + elif cfg.lpips_net == "vgg": + # The 3DGS official repo uses lpips vgg, which is equivalent with the following: + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=False + ).to(self.device) + else: + raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}") + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) + + def get_visibility_mask( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + packed: bool, + rasterize_mode: str, +): + anchors = self.splats["anchors"] # [N, 3] + # rasterization does normalization internally + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"])[:, :3] # [N, 3] + + visibility_mask = filter_visible_gaussians( + means=anchors, + quats=quats, + scales=scales, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=packed, + rasterize_mode=rasterize_mode, + ) + + return visibility_mask + + def get_neural_gaussians(self, cam_pos, selection=None): + + # If no visibility mask is provided, we select all anchors including their offsets + if selection is None: + selection = torch.ones(self.splats["anchors"].shape[0], dtype=torch.bool, device=self.device) + + selected_features = self.splats["features"][selection] # [M, c] + selected_anchors = self.splats["anchors"][selection] # [M, 3] + selected_offsets = self.splats["offsets"][selection] # [M, k, 3] + selected_scales = torch.exp(self.splats["scales"][selection]) # [M, 6] + + # See formula (5) in Scaffold-GS + view_dir = selected_anchors - cam_pos # [M, 3] + view_dir_normalized = view_dir / view_dir.norm(dim=1, keepdim=True) # [M, 3] + + # See formula (9) and the appendix for the rest + feature_view_dir = torch.cat([selected_features, view_dir_normalized], dim=1) # [M, c+3] + + k = self.cfg.strategy.n_feat_offsets # Number of offsets per anchor + + # Apply MLPs (they output per-offset features concatenated along the last dimension) + neural_opacity = self.cfg.strategy.opacities_mlp(feature_view_dir) # [M, k*1] + neural_opacity = neural_opacity.view(-1, 1) # [M*k, 1] + pos_opacity_mask = (neural_opacity > 0.0).view(-1) # [M*k] + + # Get color and reshape + neural_colors = self.cfg.strategy.colors_mlp(feature_view_dir) # [M, k*3] + neural_colors = neural_colors.view(-1, 3) # [M*k, 3] + + # Get scale and rotation and reshape + neural_scale_rot = self.cfg.strategy.scale_rot_mlp(feature_view_dir) # [M, k*7] + neural_scale_rot = neural_scale_rot.view(-1, 7) # [M*k, 7] + + # Reshape selected_offsets, scales, and anchors + selected_offsets = selected_offsets.view(-1, 3) # [M*k, 3] + scales_repeated = selected_scales.unsqueeze(1).repeat(1, k, 1).view(-1, 6) # [M*k, 6] + anchors_repeated = selected_anchors.unsqueeze(1).repeat(1, k, 1).view(-1, 3) # [M*k, 3] + + # Apply positive opacity mask + selected_opacity = neural_opacity[pos_opacity_mask].squeeze(-1) # [m] + selected_colors = neural_colors[pos_opacity_mask] # [m, 3] + selected_scale_rot = neural_scale_rot[pos_opacity_mask] # [m, 7] + selected_offsets = 
selected_offsets[pos_opacity_mask] # [m, 3] + scales_repeated = scales_repeated[pos_opacity_mask] # [m, 6] + anchors_repeated = anchors_repeated[pos_opacity_mask] # [m, 3] + + # Compute scales and rotations + scales = scales_repeated[:, 3:] * torch.sigmoid(selected_scale_rot[:, :3]) # [m, 3] + rotation = torch.nn.functional.normalize(selected_scale_rot[:, 3:7]) # [m, 4] + + # Compute offsets and anchors + offsets = selected_offsets * scales_repeated[:, :3] # [m, 3] + anchors = anchors_repeated + offsets # [m, 3] + + return anchors, selected_colors, selected_opacity, scales, rotation, neural_opacity, pos_opacity_mask + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict, Tensor]: + + visibility_mask = self.get_visibility_mask(camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + packed=self.cfg.packed, + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic", + ) + + anchors, color_mlp, opacities, scales, quats, neural_opacity, selection_mask = self.get_neural_gaussians(camtoworlds[:, :3, 3], selection=visibility_mask) + + image_ids = kwargs.pop("image_ids", None) + if self.cfg.app_opt: + colors = self.app_module( + features=self.splats["features"], + embed_ids=image_ids, + dirs=anchors[None, :, :] - camtoworlds[:, None, :3, 3], + sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + ) + colors = colors + color_mlp + colors = torch.sigmoid(colors) + else: + colors = color_mlp # [N, K, 3] + + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" + render_colors, render_alphas, info = rasterization( + means=anchors, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=self.cfg.strategy.absgrad, + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + distributed=self.world_size > 1, + **kwargs, + ) + return render_colors, render_alphas, info, scales + + def train(self): + cfg = self.cfg + device = self.device + world_rank = self.world_rank + world_size = self.world_size + + # Dump cfg. + if world_rank == 0: + with open(f"{cfg.result_dir}/cfg.yml", "w") as f: + yaml.dump(vars(cfg), f) + + max_steps = cfg.max_steps + init_step = 0 + + schedulers = [ + # anchors has a learning rate schedule, that end at 0.01 of the initial value + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["anchors"], gamma=0.01 ** (1.0 / max_steps) + ), + ] + if cfg.pose_opt: + # pose optimization has a learning rate schedule + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ) + ) + if cfg.use_bilateral_grid: + # bilateral grid has a learning rate schedule. Linear warmup for 1000 steps. + schedulers.append( + torch.optim.lr_scheduler.ChainedScheduler( + [ + torch.optim.lr_scheduler.LinearLR( + self.bil_grid_optimizers[0], + start_factor=0.01, + total_iters=1000, + ), + torch.optim.lr_scheduler.ExponentialLR( + self.bil_grid_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ), + ] + ) + ) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=4, + persistent_workers=True, + pin_memory=True, + ) + trainloader_iter = iter(trainloader) + + # Training loop. 
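Before the training loop, a brief aside on the tensor bookkeeping in `get_neural_gaussians` above: following Scaffold-GS, each of the M visible anchors is decoded into k = `n_feat_offsets` candidate Gaussians, so the per-offset MLP outputs of shape [M, k*D] are flattened to [M*k, D], the anchor positions and base scales are repeated k times to match, and only entries with positive predicted opacity are kept. A minimal shape-only sketch with a toy MLP (illustrative, not part of this diff):

```python
import torch

M, k, feat_dim = 4, 10, 32  # anchors, offsets per anchor, feature size (toy values)
feats = torch.randn(M, feat_dim)
anchors = torch.randn(M, 3)
cam_pos = torch.zeros(3)

view_dir = torch.nn.functional.normalize(anchors - cam_pos, dim=1)  # cf. formula (5)
mlp_in = torch.cat([feats, view_dir], dim=1)  # [M, feat_dim + 3], cf. formula (9)

opacity_mlp = torch.nn.Sequential(torch.nn.Linear(feat_dim + 3, k), torch.nn.Tanh())
neural_opacity = opacity_mlp(mlp_in).reshape(-1, 1)  # [M*k, 1]
keep = (neural_opacity > 0).squeeze(-1)              # positive-opacity mask, [M*k]

offsets = torch.randn(M, k, 3).reshape(-1, 3)[keep]                      # [m, 3]
anchors_rep = anchors.unsqueeze(1).expand(M, k, 3).reshape(-1, 3)[keep]  # [m, 3]
print(offsets.shape, anchors_rep.shape)
```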
+ global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + for step in pbar: + if not cfg.disable_viewer: + while self.viewer.state.status == "paused": + time.sleep(0.01) + self.viewer.lock.acquire() + tic = time.time() + + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + num_train_rays_per_step = ( + pixels.shape[0] * pixels.shape[1] * pixels.shape[2] + ) + image_ids = data["image_id"].to(device) + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + if cfg.pose_noise: + camtoworlds = self.pose_perturb(camtoworlds, image_ids) + + if cfg.pose_opt: + camtoworlds = self.pose_adjust(camtoworlds, image_ids) + + # forward + renders, alphas, info, scales = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=None, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else "RGB", + ) + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + + if cfg.use_bilateral_grid: + grid_y, grid_x = torch.meshgrid( + (torch.arange(height, device=self.device) + 0.5) / height, + (torch.arange(width, device=self.device) + 0.5) / width, + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"] + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + colors = colors + bkgd * (1.0 - alphas) + + self.cfg.strategy.step_pre_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + ) + + # loss + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - fused_ssim( + colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" + ) + loss = l1loss * (1.0 - cfg.ssim_lambda) + loss += ssimloss * cfg.ssim_lambda + loss += scales.prod(dim=1).mean() * cfg.scale_reg + if cfg.depth_loss: + # query depths from depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate loss in disparity space + disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale + loss += depthloss * cfg.depth_lambda + if cfg.use_bilateral_grid: + tvloss = 10 * total_variation_loss(self.bil_grids.grids) + loss += tvloss + + # regularizations + # not gonna work. 
Check this + # if cfg.opacity_reg > 0.0: + # loss = ( + # loss + # + cfg.opacity_reg + # * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() + # ) + + loss.backward() + + desc = f"loss={loss.item():.3f}| " + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + if cfg.pose_opt and cfg.pose_noise: + # monitor the pose error if we inject noise + pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) + desc += f"pose err={pose_err.item():.6f}| " + pbar.set_description(desc) + + if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: + mem = torch.cuda.max_memory_allocated() / 1024**3 + self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar("train/num_GS", len(self.splats["anchors"]), step) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.use_bilateral_grid: + self.writer.add_scalar("train/tvloss", tvloss.item(), step) + if cfg.tb_save_image: + canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + # save checkpoint before updating the model + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["anchors"]), + } + print("Step: ", step, stats) + with open( + f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", + "w", + ) as f: + json.dump(stats, f) + data = {"step": step, "splats": self.splats.state_dict()} + if cfg.pose_opt: + if world_size > 1: + data["pose_adjust"] = self.pose_adjust.module.state_dict() + else: + data["pose_adjust"] = self.pose_adjust.state_dict() + if cfg.app_opt: + if world_size > 1: + data["app_module"] = self.app_module.module.state_dict() + else: + data["app_module"] = self.app_module.state_dict() + torch.save( + data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" + ) + + # For now no post steps + # self.cfg.strategy.step_post_backward( + # params=self.splats, + # optimizers=self.optimizers, + # state=self.strategy_state, + # step=step, + # info=info, + # packed=cfg.packed, + # ) + + # Turn Gradients into Sparse Tensor before running optimizer + if cfg.sparse_grad: + assert cfg.packed, "Sparse gradients only work with packed mode." + gaussian_ids = info["gaussian_ids"] + for k in self.splats.keys(): + grad = self.splats[k].grad + if grad is None or grad.is_sparse: + continue + self.splats[k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] + size=self.splats[k].size(), # [N, ...] 
+ is_coalesced=len(Ks) == 1, + ) + + # optimize + for optimizer in self.optimizers.values(): + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.pose_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.app_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.bil_grid_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps]: + self.eval(step) + self.render_traj(step) + + # run compression + if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: + self.run_compression(step=step) + + if not cfg.disable_viewer: + self.viewer.lock.release() + num_train_steps_per_sec = 1.0 / (time.time() - tic) + num_train_rays_per_sec = ( + num_train_rays_per_step * num_train_steps_per_sec + ) + # Update the viewer state. + self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec + # Update the scene. + self.viewer.update(step, num_train_rays_per_step) + + @torch.no_grad() + def eval(self, step: int, stage: str = "val"): + """Entry for evaluation.""" + print("Running evaluation...") + cfg = self.cfg + device = self.device + world_rank = self.world_rank + world_size = self.world_size + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = defaultdict(list) + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + colors, _, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + ) # [1, H, W, 3] + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + colors = torch.clamp(colors, 0.0, 1.0) + canvas_list = [pixels, colors] + + if world_rank == 0: + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + imageio.imwrite( + f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", + canvas, + ) + + pixels_p = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors_p = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors_p, pixels_p)) + metrics["ssim"].append(self.ssim(colors_p, pixels_p)) + metrics["lpips"].append(self.lpips(colors_p, pixels_p)) + if cfg.use_bilateral_grid: + cc_colors = color_correct(colors, pixels) + cc_colors_p = cc_colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p)) + + if world_rank == 0: + ellipse_time /= len(valloader) + + stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()} + stats.update( + { + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["anchors"]), + } + ) + print( + f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} " + f"Time: {stats['ellipse_time']:.3f}s/image " + f"Number of GS: {stats['num_GS']}" + ) + # save stats as json + with open(f"{self.stats_dir}/{stage}_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"{stage}/{k}", v, step) + self.writer.flush() + + @torch.no_grad() + def render_traj(self, step: int): + 
"""Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds_all = self.parser.camtoworlds[5:-5] + if cfg.render_traj_path == "interp": + camtoworlds_all = generate_interpolated_path( + camtoworlds_all, 1 + ) # [N, 3, 4] + elif cfg.render_traj_path == "ellipse": + height = camtoworlds_all[:, 2, 3].mean() + camtoworlds_all = generate_ellipse_path_z( + camtoworlds_all, height=height + ) # [N, 3, 4] + elif cfg.render_traj_path == "spiral": + camtoworlds_all = generate_spiral_path( + camtoworlds_all, + bounds=self.parser.bounds * self.scene_scale, + spiral_scale_r=self.parser.extconf["spiral_radius_scale"], + ) + else: + raise ValueError( + f"Render trajectory type not supported: {cfg.render_traj_path}" + ) + + camtoworlds_all = np.concatenate( + [ + camtoworlds_all, + np.repeat( + np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 + ), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + canvas_all = [] + for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): + camtoworlds = camtoworlds_all[i : i + 1] + Ks = K[None] + + renders, _, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + depths = renders[..., 3:4] # [1, H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list = [colors, depths.repeat(1, 1, 1, 3)] + + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + canvas_all.append(canvas) + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + for canvas in canvas_all: + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def run_compression(self, step: int): + """Entry for running compression.""" + print("Running compression...") + world_rank = self.world_rank + + compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}" + os.makedirs(compress_dir, exist_ok=True) + + self.compression_method.compress(compress_dir, self.splats) + + # evaluate compression + splats_c = self.compression_method.decompress(compress_dir) + for k in splats_c.keys(): + self.splats[k].data = splats_c[k].to(self.device) + self.eval(step=step, stage="compress") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + render_colors, _, _, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=W, + height=H, + sh_degree=None, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + return render_colors[0].cpu().numpy() + + +def main(local_rank: int, world_rank, world_size: int, cfg: Config): + if world_size > 1 and not 
cfg.disable_viewer: + cfg.disable_viewer = True + if world_rank == 0: + print("Viewer is disabled in distributed training.") + + runner = Runner(local_rank, world_rank, world_size, cfg) + + if cfg.ckpt is not None: + # run eval only + ckpts = [ + torch.load(file, map_location=runner.device, weights_only=True) + for file in cfg.ckpt + ] + for k in runner.splats.keys(): + runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts]) + step = ckpts[0]["step"] + runner.eval(step=step) + runner.render_traj(step=step) + if cfg.compression is not None: + runner.run_compression(step=step) + else: + runner.train() + + if not cfg.disable_viewer: + print("Viewer running... Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + """ + Usage: + + ```bash + # Single GPU training + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default + + # Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps. + CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25 + + """ + + cfg = tyro.cli(Config) + cfg.adjust_steps(cfg.steps_scaler) + + # try import extra dependencies + if cfg.compression == "png": + try: + import plas + import torchpq + except: + raise ImportError( + "To use PNG compression, you need to install " + "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) " + "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') " + ) + + cli(main, cfg, verbose=True) diff --git a/gsplat/rendering.py b/gsplat/rendering.py index cadacaa2f..174f8121d 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -582,6 +582,218 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso return render_colors, render_alphas, meta +def filter_visible_gaussians( + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] + scales: Tensor, # [N, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, + far_plane: float = 1e10, + radius_clip: float = 0.0, + eps2d: float = 0.3, + sh_degree: Optional[int] = None, + packed: bool = True, + rasterize_mode: Literal["classic", "antialiased"] = "classic", + ortho: bool = False, + covars: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Dict]: + """Rasterize a set of 3D Gaussians (N) to a batch of image planes (C). + + This function provides a handful features for 3D Gaussian rasterization, which + we detail in the following notes. A complete profiling of the these features + can be found in the :ref:`profiling` page. + + .. note:: + **Multi-GPU Distributed Rasterization**: This function can be used in a multi-GPU + distributed scenario by setting `distributed` to True. When `distributed` is True, + a subset of total Gaussians could be passed into this function in each rank, and + the function will collaboratively render a set of images using Gaussians from all ranks. Note + to achieve balanced computation, it is recommended (not enforced) to have similar number of + Gaussians in each rank. But we do enforce that the number of cameras to be rendered + in each rank is the same. The function will return the rendered images + corresponds to the input cameras in each rank, and allows for gradients to flow back to the + Gaussians living in other ranks. For the details, please refer to the paper + `On Scaling Up 3D Gaussian Splatting Training `_. + + .. 
note:: + **Batch Rasterization**: This function allows for rasterizing a set of 3D Gaussians + to a batch of images in one go, by simplly providing the batched `viewmats` and `Ks`. + + .. note:: + **Support N-D Features**: If `sh_degree` is None, + the `colors` is expected to be with shape [N, D] or [C, N, D], in which D is the channel of + the features to be rendered. The computation is slow when D > 32 at the moment. + If `sh_degree` is set, the `colors` is expected to be the SH coefficients with + shape [N, K, 3] or [C, N, K, 3], where K is the number of SH bases. In this case, it is expected + that :math:`(\\textit{sh_degree} + 1) ^ 2 \\leq K`, where `sh_degree` controls the + activated bases in the SH coefficients. + + .. note:: + **Depth Rendering**: This function supports colors or/and depths via `render_mode`. + The supported modes are "RGB", "D", "ED", "RGB+D", and "RGB+ED". "RGB" renders the + colored image that respects the `colors` argument. "D" renders the accumulated z-depth + :math:`\\sum_i w_i z_i`. "ED" renders the expected z-depth + :math:`\\frac{\\sum_i w_i z_i}{\\sum_i w_i}`. "RGB+D" and "RGB+ED" render both + the colored image and the depth, in which the depth is the last channel of the output. + + .. note:: + **Memory-Speed Trade-off**: The `packed` argument provides a trade-off between + memory footprint and runtime. If `packed` is True, the intermediate results are + packed into sparse tensors, which is more memory efficient but might be slightly + slower. This is especially helpful when the scene is large and each camera sees only + a small portion of the scene. If `packed` is False, the intermediate results are + with shape [C, N, ...], which is faster but might consume more memory. + + .. note:: + **Sparse Gradients**: If `sparse_grad` is True, the gradients for {means, quats, scales} + will be stored in a `COO sparse layout `_. + This can be helpful for saving memory + for training when the scene is large and each iteration only activates a small portion + of the Gaussians. Usually a sparse optimizer is required to work with sparse gradients, + such as `torch.optim.SparseAdam `_. + This argument is only effective when `packed` is True. + + .. note:: + **Speed-up for Large Scenes**: The `radius_clip` argument is extremely helpful for + speeding up large scale scenes or scenes with large depth of fields. Gaussians with + 2D radius smaller or equal than this value (in pixel unit) will be skipped during rasterization. + This will skip all the far-away Gaussians that are too small to be seen in the image. + But be warned that if there are close-up Gaussians that are also below this threshold, they will + also get skipped (which is rarely happened in practice). This is by default disabled by setting + `radius_clip` to 0.0. + + .. note:: + **Antialiased Rendering**: If `rasterize_mode` is "antialiased", the function will + apply a view-dependent compensation factor + :math:`\\rho=\\sqrt{\\frac{Det(\\Sigma)}{Det(\\Sigma+ \\epsilon I)}}` to Gaussian + opacities, where :math:`\\Sigma` is the projected 2D covariance matrix and :math:`\\epsilon` + is the `eps2d`. This will make the rendered image more antialiased, as proposed in + the paper `Mip-Splatting: Alias-free 3D Gaussian Splatting `_. + + .. note:: + **AbsGrad**: If `absgrad` is True, the absolute gradients of the projected + 2D means will be computed during the backward pass, which could be accessed by + `meta["means2d"].absgrad`. 
This is an implementation of the paper + `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_, + which is shown to be more effective for splitting Gaussians during training. + + .. warning:: + This function is currently not differentiable w.r.t. the camera intrinsics `Ks`. + + Args: + means: The 3D centers of the Gaussians. [N, 3] + quats: The quaternions of the Gaussians (wxyz convension). It's not required to be normalized. [N, 4] + scales: The scales of the Gaussians. [N, 3] + viewmats: The world-to-cam transformation of the cameras. [C, 4, 4] + Ks: The camera intrinsics. [C, 3, 3] + width: The width of the image. + height: The height of the image. + near_plane: The near plane for clipping. Default is 0.01. + far_plane: The far plane for clipping. Default is 1e10. + radius_clip: Gaussians with 2D radius smaller or equal than this value will be + skipped. This is extremely helpful for speeding up large scale scenes. + Default is 0.0. + eps2d: An epsilon added to the egienvalues of projected 2D covariance matrices. + This will prevents the projected GS to be too small. For example eps2d=0.3 + leads to minimal 3 pixel unit. Default is 0.3. + sh_degree: The SH degree to use, which can be smaller than the total + number of bands. If set, the `colors` should be [(C,) N, K, 3] SH coefficients, + else the `colors` should [(C,) N, D] post-activation color values. Default is None. + packed: Whether to use packed mode which is more memory efficient but might or + might not be as fast. Default is True. + tile_size: The size of the tiles for rasterization. Default is 16. + (Note: other values are not tested) + backgrounds: The background colors. [C, D]. Default is None. + render_mode: The rendering mode. Supported modes are "RGB", "D", "ED", "RGB+D", + and "RGB+ED". "RGB" renders the colored image, "D" renders the accumulated depth, and + "ED" renders the expected depth. Default is "RGB". + sparse_grad: If true, the gradients for {means, quats, scales} will be stored in + a COO sparse layout. This can be helpful for saving memory. Default is False. + absgrad: If true, the absolute gradients of the projected 2D means + will be computed during the backward pass, which could be accessed by + `meta["means2d"].absgrad`. Default is False. + rasterize_mode: The rasterization mode. Supported modes are "classic" and + "antialiased". Default is "classic". + channel_chunk: The number of channels to render in one go. Default is 32. + If the required rendering channels are larger than this value, the rendering + will be done looply in chunks. + distributed: Whether to use distributed rendering. Default is False. If True, + The input Gaussians are expected to be a subset of scene in each rank, and + the function will collaboratively render the images for all ranks. + ortho: Whether to use orthographic projection. In such case fx and fy become the scaling + factors to convert projected coordinates into pixel space and cx, cy become offsets. + covars: Optional covariance matrices of the Gaussians. If provided, the `quats` and + `scales` will be ignored. [N, 3, 3], Default is None. + + Returns: + A tuple: + + **render_colors**: The rendered colors. [C, height, width, X]. + X depends on the `render_mode` and input `colors`. If `render_mode` is "RGB", + X is D; if `render_mode` is "D" or "ED", X is 1; if `render_mode` is "RGB+D" or + "RGB+ED", X is D+1. + + **render_alphas**: The rendered alphas. [C, height, width, 1]. + + **meta**: A dictionary of intermediate results of the rasterization. 
+ + """ + N = means.shape[0] + C = viewmats.shape[0] + assert means.shape == (N, 3), means.shape + if covars is None: + assert quats.shape == (N, 4), quats.shape + assert scales.shape == (N, 3), scales.shape + else: + assert covars.shape == (N, 3, 3), covars.shape + quats, scales = None, None + # convert covars from 3x3 matrix to upper-triangular 6D vector + tri_indices = ([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]) + covars = covars[..., tri_indices[0], tri_indices[1]] + assert viewmats.shape == (C, 4, 4), viewmats.shape + assert Ks.shape == (C, 3, 3), Ks.shape + + # Project Gaussians to 2D. Directly pass in {quats, scales} is faster than precomputing covars. + proj_results = fully_fused_projection( + means, + covars, + quats, + scales, + viewmats, + Ks, + width, + height, + eps2d=eps2d, + packed=packed, + near_plane=near_plane, + far_plane=far_plane, + radius_clip=radius_clip, + sparse_grad=False, + calc_compensations=(rasterize_mode == "antialiased"), + ortho=ortho, + ) + + if packed: + # The results are packed into shape [nnz, ...]. All elements are valid. + ( + _, + _, + radii, + _, + _, + _, + _, + ) = proj_results + else: + # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. + radii, _, _, _, _ = proj_results + + return radii.squeeze(0) > 0 + + def _rasterization( means: Tensor, # [N, 3] quats: Tensor, # [N, 4] @@ -1447,7 +1659,7 @@ def rasterization_2dgs_inria_wrapper( render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median ) - normals_surf = depth_to_normal(render_depth, torch.linalg.inv(viewmats), Ks) + normals_surf = depth_to_normal(render_depth, viewmats, Ks) normals_surf = normals_surf * (render_alphas).detach() render_colors = torch.cat([render_colors, render_depth], dim=-1) diff --git a/gsplat/strategy/__init__.py b/gsplat/strategy/__init__.py index 305dc8129..08ac72b8f 100644 --- a/gsplat/strategy/__init__.py +++ b/gsplat/strategy/__init__.py @@ -1,3 +1,4 @@ from .base import Strategy from .default import DefaultStrategy from .mcmc import MCMCStrategy +from .scaffold import ScaffoldStrategy diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py new file mode 100644 index 000000000..b62c777ff --- /dev/null +++ b/gsplat/strategy/scaffold.py @@ -0,0 +1,309 @@ +from dataclasses import dataclass +from typing import Any, Dict, Tuple, Union + +import torch + +from .base import Strategy +from .ops import duplicate, remove + + +@dataclass +class ScaffoldStrategy(Strategy): + """A neural gaussian strategy that follows the paper: + + `Scaffold-GS: Structured 3D Gaussians for View-Adaptive Rendering `_ + + The strategy will: + + - Periodically duplicate GSs with high image plane gradients and small scales. + - Periodically split GSs with high image plane gradients and large scales. + - Periodically prune GSs with low opacity. + - Periodically reset GSs to a lower opacity. + + If `absgrad=True`, it will use the absolute gradients instead of average gradients + for GS duplicating & splitting, following the AbsGS paper: + + `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_ + + Which typically leads to better results but requires to set the `grow_grad2d` to a + higher value, e.g., 0.0008. Also, the :func:`rasterization` function should be called + with `absgrad=True` as well so that the absolute gradients are computed. + + Args: + prune_opa (float): GSs with opacity below this value will be pruned. Default is 0.005. 
+ grow_grad2d (float): GSs with image plane gradient above this value will be + split/duplicated. Default is 0.0002. + grow_scale3d (float): GSs with 3d scale (normalized by scene_scale) below this + value will be duplicated. Above will be split. Default is 0.01. + grow_scale2d (float): GSs with 2d scale (normalized by image resolution) above + this value will be split. Default is 0.05. + prune_scale3d (float): GSs with 3d scale (normalized by scene_scale) above this + value will be pruned. Default is 0.1. + prune_scale2d (float): GSs with 2d scale (normalized by image resolution) above + this value will be pruned. Default is 0.15. + refine_scale2d_stop_iter (int): Stop refining GSs based on 2d scale after this + iteration. Default is 0. Set to a positive value to enable this feature. + refine_start_iter (int): Start refining GSs after this iteration. Default is 500. + refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000. + refine_every (int): Refine GSs every this steps. Default is 100. + pause_refine_after_reset (int): Pause refining GSs until this number of steps after + reset, Default is 0 (no pause at all) and one might want to set this number to the + number of images in training set. + absgrad (bool): Use absolute gradients for GS splitting. Default is False. + revised_opacity (bool): Whether to use revised opacity heuristic from + arXiv:2404.06109 (experimental). Default is False. + verbose (bool): Whether to print verbose information. Default is False. + + Examples: + + >>> from gsplat import DefaultStrategy, rasterization + >>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ... + >>> optimizers: Dict[str, torch.optim.Optimizer] = ... + >>> strategy = DefaultStrategy() + >>> strategy.check_sanity(params, optimizers) + >>> strategy_state = strategy.initialize_state() + >>> for step in range(1000): + ... render_image, render_alpha, info = rasterization(...) + ... strategy.step_pre_backward(params, optimizers, strategy_state, step, info) + ... loss = ... + ... loss.backward() + ... strategy.step_post_backward(params, optimizers, strategy_state, step, info) + + """ + + prune_opa: float = 0.005 + grow_grad2d: float = 0.0002 + grow_scale3d: float = 0.01 + grow_scale2d: float = 0.05 + prune_scale3d: float = 0.1 + prune_scale2d: float = 0.15 + refine_scale2d_stop_iter: int = 0 + refine_start_iter: int = 500 + refine_stop_iter: int = 15_000 + mean_feat_dim: int = 32 + n_feat_offsets: int = 10 + refine_every: int = 100 + pause_refine_after_reset: int = 0 + absgrad: bool = False + revised_opacity: bool = False + verbose: bool = False + colors_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(mean_feat_dim + 3, mean_feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(mean_feat_dim, 3 * n_feat_offsets), + torch.nn.Sigmoid() + ).cuda() + opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(mean_feat_dim + 3, mean_feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(mean_feat_dim, n_feat_offsets), + torch.nn.Tanh() + ).cuda() + scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(mean_feat_dim + 3, mean_feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(mean_feat_dim, 7 * n_feat_offsets) + ).cuda() + + def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: + """Initialize and return the running state for this strategy. + + The returned state should be passed to the `step_pre_backward()` and + `step_post_backward()` functions. 
+ """ + # Postpone the initialization of the state to the first step so that we can + # put them on the correct device. + # - grad2d: running accum of the norm of the image plane gradients for each GS. + # - count: running accum of how many time each GS is visible. + # - radii: the radii of the GSs (normalized by the image resolution). + state = {"grad2d": None, + "count": None, + "scene_scale": scene_scale} + if self.refine_scale2d_stop_iter > 0: + state["radii"] = None + return state + + def check_sanity( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + ): + """Sanity check for the parameters and optimizers. + + Check if: + * `params` and `optimizers` have the same keys. + * Each optimizer has exactly one param_group, corresponding to each parameter. + * The following keys are present: {"anchors", "scales", "quats", "opacities"}. + + Raises: + AssertionError: If any of the above conditions is not met. + + .. note:: + It is not required but highly recommended for the user to call this function + after initializing the strategy to ensure the convention of the parameters + and optimizers is as expected. + """ + + super().check_sanity(params, optimizers) + # The following keys are required for this strategy. + expected_params = ["anchors", + "features", + "offsets", + "scales", + "quats", + "opacities_mlp", + "colors_mlp", + "scale_rot_mlp"] + + assert len(expected_params) == len(params), "expected params and actual params don't match" + for key in expected_params: + assert key in params, f"{key} is required in params but missing." + + def step_pre_backward( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + step: int, + info: Dict[str, Any], + ): + """Callback function to be executed before the `loss.backward()` call.""" + assert ( + "means2d" in info + ), "The 2D anchors of the Gaussians is required but missing." + info["means2d"].retain_grad() + + def step_post_backward( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + step: int, + info: Dict[str, Any], + packed: bool = False, + ): + """Callback function to be executed after the `loss.backward()` call.""" + if step >= self.refine_stop_iter: + return + + self._update_state(params, state, info, packed=packed) + + if ( + step > self.refine_start_iter + and step % self.refine_every == 0 + ): + # grow GSs + n_dupli, n_split = self._grow_anchors(params, optimizers, state, step) + if self.verbose: + print( + f"Step {step}: {n_dupli} GSs duplicated, {n_split} GSs split. " + f"Now having {len(params['anchors'])} GSs." + ) + + # prune GSs + n_prune = self._prune_gs(params, optimizers, state, step) + if self.verbose: + print( + f"Step {step}: {n_prune} GSs pruned. " + f"Now having {len(params['anchors'])} GSs." + ) + + # reset running stats + state["grad2d"].zero_() + state["count"].zero_() + + if self.refine_scale2d_stop_iter > 0: + state["radii"].zero_() + torch.cuda.empty_cache() + + def _update_state( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + state: Dict[str, Any], + info: Dict[str, Any], + packed: bool = False, + ): + for key in ["anchors2d", "width", "height", "n_cameras", "radii", "gaussian_ids"]: + assert key in info, f"{key} is required but missing." 
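The block that follows rescales the per-Gaussian 2D-mean gradients from pixel units into the [-1, 1] screen-space convention mentioned in its comment (the `n_cameras` factor keeps the statistic comparable when several cameras are rendered per step). A toy illustration of that conversion (standalone, not part of this diff):

```python
import torch

width, height, n_cameras = 640, 480, 1
# A gradient of one full NDC unit, expressed per pixel:
grads = torch.tensor([[2.0 / width, 2.0 / height]])
grads[..., 0] *= width / 2.0 * n_cameras
grads[..., 1] *= height / 2.0 * n_cameras
print(grads)  # tensor([[1., 1.]]) -- back in [-1, 1] screen units
```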
+
+        # normalize grads to [-1, 1] screen space
+        if self.absgrad:
+            grads = info["anchors2d"].absgrad.clone()
+        else:
+            grads = info["anchors2d"].grad.clone()
+        grads[..., 0] *= info["width"] / 2.0 * info["n_cameras"]
+        grads[..., 1] *= info["height"] / 2.0 * info["n_cameras"]
+
+        # initialize state on the first run
+        n_gaussian = params["anchors"].shape[0]
+        if state["grad2d"] is None:
+            state["grad2d"] = torch.zeros(n_gaussian, device=grads.device)
+        if state["count"] is None:
+            state["count"] = torch.zeros(n_gaussian, device=grads.device)
+        if self.refine_scale2d_stop_iter > 0 and state["radii"] is None:
+            assert "radii" in info, "radii is required but missing."
+            state["radii"] = torch.zeros(n_gaussian, device=grads.device)
+
+        # update the running state
+        if packed:
+            # grads is [nnz, 2]
+            gs_ids = info["gaussian_ids"]  # [nnz]
+            radii = info["radii"]  # [nnz]
+        else:
+            # grads is [C, N, 2]
+            sel = info["radii"] > 0.0  # [C, N]
+            gs_ids = torch.where(sel)[1]  # [nnz]
+            grads = grads[sel]  # [nnz, 2]
+            radii = info["radii"][sel]  # [nnz]
+
+        state["grad2d"].index_add_(0, gs_ids, grads.norm(dim=-1))
+        state["count"].index_add_(
+            0, gs_ids, torch.ones_like(gs_ids, dtype=torch.float32)
+        )
+        if self.refine_scale2d_stop_iter > 0:
+            # Ideally this should use scatter max
+            state["radii"][gs_ids] = torch.maximum(
+                state["radii"][gs_ids],
+                # normalize radii to [0, 1] screen space
+                radii / float(max(info["width"], info["height"])),
+            )
+
+    @torch.no_grad()
+    def _grow_anchors(
+        self,
+        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
+        optimizers: Dict[str, torch.optim.Optimizer],
+        state: Dict[str, Any],
+        step: int,
+    ) -> Tuple[int, int]:
+        pass
+
+    @torch.no_grad()
+    def _prune_gs(
+        self,
+        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
+        optimizers: Dict[str, torch.optim.Optimizer],
+        state: Dict[str, Any],
+        step: int,
+    ) -> int:
+        is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa
+        if step > self.reset_every:
+            is_too_big = (
+                torch.exp(params["scales"]).max(dim=-1).values
+                > self.prune_scale3d * state["scene_scale"]
+            )
+            # The official code also implements screen-size pruning but
+            # it's actually not being used due to a bug:
+            # https://github.com/graphdeco-inria/gaussian-splatting/issues/123
+            # We implement it here for completeness but set `refine_scale2d_stop_iter`
+            # to 0 by default to disable it.
+            if step < self.refine_scale2d_stop_iter:
+                is_too_big |= state["radii"] > self.prune_scale2d
+
+            is_prune = is_prune | is_too_big
+
+        n_prune = is_prune.sum().item()
+        if n_prune > 0:
+            remove(params=params, optimizers=optimizers, state=state, mask=is_prune)
+
+        return n_prune
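For completeness, a minimal usage sketch of the new `filter_visible_gaussians` helper, mirroring how `get_visibility_mask` in the example trainer calls it (toy tensors; a CUDA device is assumed, as for the rest of gsplat; not part of this diff):

```python
import torch
from gsplat.rendering import filter_visible_gaussians

device = "cuda"
N = 1024
means = torch.randn(N, 3, device=device)
quats = torch.randn(N, 4, device=device)      # normalized internally
scales = torch.rand(N, 3, device=device) * 0.05
viewmats = torch.eye(4, device=device)[None]  # [1, 4, 4] world-to-camera
Ks = torch.tensor(
    [[300.0, 0.0, 128.0], [0.0, 300.0, 128.0], [0.0, 0.0, 1.0]], device=device
)[None]                                       # [1, 3, 3]

visible = filter_visible_gaussians(
    means, quats, scales, viewmats, Ks,
    width=256, height=256, packed=False, rasterize_mode="classic",
)
print(visible.shape, visible.dtype)  # boolean visibility mask, one entry per Gaussian
```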