From 3c9dc0e36eb6765ee43ebd5ab4bebe784210e7e3 Mon Sep 17 00:00:00 2001 From: BaowenZ <947976219@qq.com> Date: Wed, 7 Aug 2024 14:36:48 +0800 Subject: [PATCH 1/2] add RaDeGS --- examples/datasets/colmap.py | 26 +- examples/simple_trainer_recon.py | 938 ++++++++++++++++++ extract_dtu.sh | 12 + gsplat/cuda/_wrapper.py | 256 ++++- gsplat/cuda/csrc/bindings.h | 76 +- gsplat/cuda/csrc/ext.cpp | 6 +- .../cuda/csrc/fully_fused_projection_bwd.cu | 16 +- .../cuda/csrc/fully_fused_projection_fwd.cu | 27 +- .../csrc/fully_fused_projection_packed_bwd.cu | 22 +- .../csrc/fully_fused_projection_packed_fwd.cu | 31 +- gsplat/cuda/csrc/persp_proj_bwd.cu | 13 +- gsplat/cuda/csrc/persp_proj_fwd.cu | 23 +- gsplat/cuda/csrc/rasterize_to_pixels_bwd.cu | 262 ++++- gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu | 206 +++- gsplat/cuda/csrc/symeigen.cuh | 227 +++++ gsplat/cuda/csrc/utils.cuh | 213 +++- gsplat/rendering.py | 31 +- train_dtu.sh | 20 + 18 files changed, 2286 insertions(+), 119 deletions(-) create mode 100644 examples/simple_trainer_recon.py create mode 100644 extract_dtu.sh create mode 100644 gsplat/cuda/csrc/symeigen.cuh create mode 100644 train_dtu.sh diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 22eacc2ac..254c1171c 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -138,9 +138,14 @@ def __init__( image_dir_suffix = "" colmap_image_dir = os.path.join(data_dir, "images") image_dir = os.path.join(data_dir, "images" + image_dir_suffix) - for d in [image_dir, colmap_image_dir]: - if not os.path.exists(d): - raise ValueError(f"Image folder {d} does not exist.") + + if not os.path.exists(colmap_image_dir): + raise ValueError(f"Image folder {colmap_image_dir} does not exist.") + if not os.path.exists(image_dir): + image_dir = colmap_image_dir + self.factor = factor + else: + self.factor = None # Downsampled images may have different names vs images used for COLMAP, # so we need to map between the two sorted lists of files. @@ -249,12 +254,22 @@ def __len__(self): def __getitem__(self, item: int) -> Dict[str, Any]: index = self.indices[item] - image = imageio.imread(self.parser.image_paths[index])[..., :3] + image = imageio.imread(self.parser.image_paths[index]) + if image.shape[-1]==4: + mask = image[...,-1] + image = image[...,:3] + else: + mask = None camera_id = self.parser.camera_ids[index] K = self.parser.Ks_dict[camera_id].copy() # undistorted K params = self.parser.params_dict[camera_id] camtoworlds = self.parser.camtoworlds[index] + if self.parser.factor is not None: + image = cv2.resize(image,self.parser.imsize_dict[camera_id]) + if mask is not None: + mask = cv2.resize(mask,self.parser.imsize_dict[camera_id]) + if len(params) > 0: # Images are distorted. Undistort them. 
mapx, mapy = ( @@ -274,11 +289,14 @@ def __getitem__(self, item: int) -> Dict[str, Any]: K[0, 2] -= x K[1, 2] -= y + + data = { "K": torch.from_numpy(K).float(), "camtoworld": torch.from_numpy(camtoworlds).float(), "image": torch.from_numpy(image).float(), "image_id": item, # the index of the image in the dataset + "mask": torch.from_numpy(mask).float(), } if self.load_depths: diff --git a/examples/simple_trainer_recon.py b/examples/simple_trainer_recon.py new file mode 100644 index 000000000..0db61f286 --- /dev/null +++ b/examples/simple_trainer_recon.py @@ -0,0 +1,938 @@ +import json +import math +import os +import time +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import imageio +import nerfview +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import viser +from datasets.colmap import Dataset, Parser +from datasets.traj import generate_interpolated_path +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed + +from gsplat.rendering import rasterization +from gsplat.strategy import DefaultStrategy + +import open3d as o3d +import open3d.core as o3c + + +@dataclass +class Config: + # Disable viewer + disable_viewer: bool = False + # Path to the .pt file. If provide, it will skip training and render a video + ckpt: Optional[str] = None + + # Path to the Mip-NeRF 360 dataset + data_dir: str = "data/360_v2/garden" + # Downsample factor for the dataset + data_factor: int = 4 + # Directory to save results + result_dir: str = "results/garden" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 30_000 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + + # Initialization strategy + init_type: str = "sfm" + # Initial number of GSs. Ignored if using sfm + init_num_pts: int = 100_000 + # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm + init_extent: float = 3.0 + # Degree of spherical harmonics + sh_degree: int = 3 + # Turn on another SH degree every this steps + sh_degree_interval: int = 1000 + # Initial opacity of GS + init_opa: float = 0.1 + # Initial scale of GS + init_scale: float = 1.0 + # Weight for SSIM loss + ssim_lambda: float = 0.2 + + # Near plane clipping distance + near_plane: float = 0.01 + # Far plane clipping distance + far_plane: float = 1e10 + + # GSs with opacity below this value will be pruned + prune_opa: float = 0.005 + # GSs with image plane gradient above this value will be split/duplicated + grow_grad2d: float = 0.0002 + # GSs with scale below this value will be duplicated. 
Above will be split + grow_scale3d: float = 0.01 + # GSs with scale above this value will be pruned. + prune_scale3d: float = 0.1 + + # Start refining GSs after this iteration + refine_start_iter: int = 500 + # Stop refining GSs after this iteration + refine_stop_iter: int = 15_000 + # Reset opacities every this steps + reset_every: int = 3000 + # Refine GSs every this steps + refine_every: int = 100 + + # Use packed mode for rasterization, this leads to less memory usage but slightly slower. + packed: bool = False + # Use sparse gradients for optimization. (experimental) + sparse_grad: bool = False + # Use absolute gradient for pruning. This typically requires larger --grow_grad2d, e.g., 0.0008 or 0.0006 + absgrad: bool = False + # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. + antialiased: bool = False + # Whether to use revised opacity heuristic from arXiv:2404.06109 (experimental) + revised_opacity: bool = False + + # Use random background for training to discourage transparency + random_bkgd: bool = False + + # Enable camera optimization. + pose_opt: bool = False + # Learning rate for camera optimization + pose_opt_lr: float = 1e-5 + # Regularization for camera optimization as weight decay + pose_opt_reg: float = 1e-6 + # Add noise to camera extrinsics. This is only to test the camera pose optimization. + pose_noise: float = 0.0 + + # Enable appearance optimization. (experimental) + app_opt: bool = False + # Appearance embedding dimension + app_embed_dim: int = 16 + # Learning rate for appearance optimization + app_opt_lr: float = 1e-3 + # Regularization for appearance optimization as weight decay + app_opt_reg: float = 1e-6 + + # Enable depth loss. (experimental) + depth_loss: bool = False + # Weight for depth loss + depth_lambda: float = 1e-2 + + # Dump information to tensorboard every this steps + tb_every: int = 100 + # Save training images to tensorboard + tb_save_image: bool = False + + normal_consistency_loss: bool = False + normal_consistency_lambda: float = 0.05 + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.max_steps = int(self.max_steps * factor) + self.sh_degree_interval = int(self.sh_degree_interval * factor) + self.refine_start_iter = int(self.refine_start_iter * factor) + self.refine_stop_iter = int(self.refine_stop_iter * factor) + self.reset_every = int(self.reset_every * factor) + self.refine_every = int(self.refine_every * factor) + + +def create_splats_with_optimizers( + parser: Parser, + init_type: str = "sfm", + init_num_pts: int = 100_000, + init_extent: float = 3.0, + init_opacity: float = 0.1, + init_scale: float = 1.0, + scene_scale: float = 1.0, + sh_degree: int = 3, + sparse_grad: bool = False, + batch_size: int = 1, + feature_dim: Optional[int] = None, + device: str = "cuda", +) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]: + if init_type == "sfm": + points = torch.from_numpy(parser.points).float() + rgbs = torch.from_numpy(parser.points_rgb / 255.0).float() + elif init_type == "random": + points = init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1) + rgbs = torch.rand((init_num_pts, 3)) + else: + raise ValueError("Please specify a correct init_type: sfm or random") + + N = points.shape[0] + # Initialize the GS size to be the average dist of the 3 nearest neighbors + dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + 
scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] + quats = torch.rand((N, 4)) # [N, 4] + opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] + + params = [ + # name, value, lr + ("means", torch.nn.Parameter(points), 1.6e-4 * scene_scale), + ("scales", torch.nn.Parameter(scales), 5e-3), + ("quats", torch.nn.Parameter(quats), 1e-3), + ("opacities", torch.nn.Parameter(opacities), 5e-2), + ] + + if feature_dim is None: + # color is SH coefficients. + colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] + colors[:, 0, :] = rgb_to_sh(rgbs) + params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3)) + params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20)) + else: + # features will be used for appearance and view-dependent shading + features = torch.rand(N, feature_dim) # [N, feature_dim] + params.append(("features", torch.nn.Parameter(features), 2.5e-3)) + colors = torch.logit(rgbs) # [N, 3] + params.append(("colors", torch.nn.Parameter(colors), 2.5e-3)) + + splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) + # Scale learning rate based on batch size, reference: + # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + # Note that this would not make the training exactly equivalent, see + # https://arxiv.org/pdf/2402.18824v1 + optimizers = { + name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + [{"params": splats[name], "lr": lr * math.sqrt(batch_size)}], + eps=1e-15 / math.sqrt(batch_size), + betas=(1 - batch_size * (1 - 0.9), 1 - batch_size * (1 - 0.999)), + ) + for name, _, lr in params + } + return splats, optimizers + + +class Runner: + """Engine for training and testing.""" + + def __init__(self, cfg: Config) -> None: + set_random_seed(42) + + self.cfg = cfg + self.device = "cuda" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}/ckpts" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # Load data: Training data should contain initial points and colors. + self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=True, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + print("Scene scale:", self.scene_scale) + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers = create_splats_with_optimizers( + self.parser, + init_type=cfg.init_type, + init_num_pts=cfg.init_num_pts, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + scene_scale=self.scene_scale, + sh_degree=cfg.sh_degree, + sparse_grad=cfg.sparse_grad, + batch_size=cfg.batch_size, + feature_dim=feature_dim, + device=self.device, + ) + print("Model initialized. 
Number of GS:", len(self.splats["means"])) + + # Densification Strategy + self.strategy = DefaultStrategy( + verbose=True, + scene_scale=self.scene_scale, + prune_opa=cfg.prune_opa, + grow_grad2d=cfg.grow_grad2d, + grow_scale3d=cfg.grow_scale3d, + prune_scale3d=cfg.prune_scale3d, + # refine_scale2d_stop_iter=4000, # splatfacto behavior + refine_start_iter=cfg.refine_start_iter, + refine_stop_iter=cfg.refine_stop_iter, + reset_every=cfg.reset_every, + refine_every=cfg.refine_every, + absgrad=cfg.absgrad, + revised_opacity=cfg.revised_opacity, + ) + self.strategy.check_sanity(self.splats, self.optimizers) + self.strategy_state = self.strategy.initialize_state() + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + + self.app_optimizers = [] + if cfg.app_opt: + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. + torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + + # Losses & Metrics. 
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to( + self.device + ) + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict]: + means = self.splats["means"] # [N, 3] + # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] + # rasterization does normalization internally + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + + image_ids = kwargs.pop("image_ids", None) + if self.cfg.app_opt: + colors = self.app_module( + features=self.splats["features"], + embed_ids=image_ids, + dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], + sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + ) + colors = colors + self.splats["colors"] + colors = torch.sigmoid(colors) + else: + colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] + + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" + render_colors, render_alphas, expected_depths, median_depths, render_normals, info = rasterization( + means=means, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=self.cfg.absgrad, + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + **kwargs, + ) + return render_colors, render_alphas, expected_depths, median_depths, render_normals, info + + def train(self): + cfg = self.cfg + device = self.device + + # Dump cfg. + with open(f"{cfg.result_dir}/cfg.json", "w") as f: + json.dump(vars(cfg), f) + + max_steps = cfg.max_steps + init_step = 0 + + schedulers = [ + # means has a learning rate schedule, that end at 0.01 of the initial value + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps) + ), + ] + if cfg.pose_opt: + # pose optimization has a learning rate schedule + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ) + ) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=4, + persistent_workers=True, + pin_memory=True, + ) + trainloader_iter = iter(trainloader) + + # Training loop. 
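        # RaDe-GS additions in the training loop below: when
        # --normal_consistency_loss is enabled and step >= 15000,
        # rasterize_splats() is called with require_rade=rade so that expected
        # and median depth maps plus a rendered normal map are returned, and a
        # normal-consistency term
        #     normal_consistency_lambda * (0.4 * err_expected + 0.6 * err_median)
        # is added to the photometric loss, where each err term is the mean of
        # 1 - <rendered normal, depth-derived normal> over pixels.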
+ global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + for step in pbar: + if not cfg.disable_viewer: + while self.viewer.state.status == "paused": + time.sleep(0.01) + self.viewer.lock.acquire() + tic = time.time() + + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + num_train_rays_per_step = ( + pixels.shape[0] * pixels.shape[1] * pixels.shape[2] + ) + image_ids = data["image_id"].to(device) + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + if cfg.pose_noise: + camtoworlds = self.pose_perturb(camtoworlds, image_ids) + + if cfg.pose_opt: + camtoworlds = self.pose_adjust(camtoworlds, image_ids) + + # sh schedule + sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree) + + rade = cfg.normal_consistency_loss and (step>=15000) + # forward + renders, alphas, expected_depths, median_depths, normals, info = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=sh_degree_to_use, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else "RGB", + require_rade = rade + ) + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + if renders.shape[-1] == 4: + colors = renders[..., 0:3] + else: + colors = renders + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + colors = colors + bkgd * (1.0 - alphas) + + self.strategy.step_pre_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + ) + + # loss + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - self.ssim( + pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2) + ) + loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda + if cfg.depth_loss: + # query depths from depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate loss in disparity space + disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale + loss += depthloss * cfg.depth_lambda + + if rade: + grid_x, grid_y = torch.meshgrid(torch.arange(width)+0.5, torch.arange(height)+0.5, indexing='xy') + points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(1, -1, 3).float().cuda() + rays_d = points @ torch.linalg.inv(Ks.transpose(2,1)) # 1, M, 3 + points_e = expected_depths.reshape(Ks.shape[0],-1,1) * rays_d + points_m = median_depths.reshape(Ks.shape[0],-1,1) * rays_d + points_e = points_e.reshape_as(normals) + points_m = points_m.reshape_as(normals) + normal_map_e = torch.zeros_like(points_e) + dx = points_e[...,2:, 1:-1,:] - points_e[...,:-2, 1:-1,:] + dy = points_e[...,1:-1, 2:,:] - points_e[...,1:-1, :-2,:] + normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, 
dim=-1), dim=-1) + normal_map_e[...,1:-1, 1:-1, :] = normal_map + normal_map_m = torch.zeros_like(points_m) + dx = points_m[...,2:, 1:-1,:] - points_m[...,:-2, 1:-1,:] + dy = points_m[...,1:-1, 2:,:] - points_m[...,1:-1, :-2,:] + normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) + normal_map_m[...,1:-1, 1:-1, :] = normal_map + normal_error_map_e = (1 - (normals * normal_map_e).sum(dim=-1)) + normal_error_map_m = (1 - (normals * normal_map_m).sum(dim=-1)) + loss += cfg.normal_consistency_lambda * (0.4 * normal_error_map_e.mean() + 0.6 * normal_error_map_m.mean()) + + loss.backward() + + desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + if cfg.pose_opt and cfg.pose_noise: + # monitor the pose error if we inject noise + pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) + desc += f"pose err={pose_err.item():.6f}| " + pbar.set_description(desc) + + if cfg.tb_every > 0 and step % cfg.tb_every == 0: + mem = torch.cuda.max_memory_allocated() / 1024**3 + self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.tb_save_image: + canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + self.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + ) + + # Turn Gradients into Sparse Tensor before running optimizer + if cfg.sparse_grad: + assert cfg.packed, "Sparse gradients only work with packed mode." + gaussian_ids = info["gaussian_ids"] + for k in self.splats.keys(): + grad = self.splats[k].grad + if grad is None or grad.is_sparse: + continue + self.splats[k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] + size=self.splats[k].size(), # [N, ...] 
+ is_coalesced=len(Ks) == 1, + ) + + # optimize + for optimizer in self.optimizers.values(): + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.pose_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.app_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # save checkpoint + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["means"]), + } + print("Step: ", step, stats) + with open(f"{self.stats_dir}/train_step{step:04d}.json", "w") as f: + json.dump(stats, f) + torch.save( + { + "step": step, + "splats": self.splats.state_dict(), + }, + f"{self.ckpt_dir}/ckpt_{step}.pt", + ) + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1: + self.eval(step) + self.render_traj(step) + + if not cfg.disable_viewer: + self.viewer.lock.release() + num_train_steps_per_sec = 1.0 / (time.time() - tic) + num_train_rays_per_sec = ( + num_train_rays_per_step * num_train_steps_per_sec + ) + # Update the viewer state. + self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec + # Update the scene. + self.viewer.update(step, num_train_rays_per_step) + + @torch.no_grad() + def eval(self, step: int): + """Entry for evaluation.""" + if len(self.valset) == 0: + return + print("Running evaluation...") + cfg = self.cfg + device = self.device + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = {"psnr": [], "ssim": [], "lpips": []} + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + colors, _, _, _, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + ) # [1, H, W, 3] + colors = torch.clamp(colors, 0.0, 1.0) + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + # write images + canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() + imageio.imwrite( + f"{self.render_dir}/val_{i:04d}.png", (canvas * 255).astype(np.uint8) + ) + + pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors, pixels)) + metrics["ssim"].append(self.ssim(colors, pixels)) + metrics["lpips"].append(self.lpips(colors, pixels)) + + ellipse_time /= len(valloader) + + psnr = torch.stack(metrics["psnr"]).mean() + ssim = torch.stack(metrics["ssim"]).mean() + lpips = torch.stack(metrics["lpips"]).mean() + print( + f"PSNR: {psnr.item():.3f}, SSIM: {ssim.item():.4f}, LPIPS: {lpips.item():.3f} " + f"Time: {ellipse_time:.3f}s/image " + f"Number of GS: {len(self.splats['means'])}" + ) + # save stats as json + stats = { + "psnr": psnr.item(), + "ssim": ssim.item(), + "lpips": lpips.item(), + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["means"]), + } + with open(f"{self.stats_dir}/val_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"val/{k}", v, step) + self.writer.flush() + + 
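    # Note: both the normal-consistency loss above and render_traj() below turn
    # a rasterized depth map into a normal map by back-projecting pixel centers
    # with the inverse intrinsics and crossing finite differences of the
    # resulting camera-space points. A minimal standalone sketch of that step
    # (illustrative names; single pinhole intrinsic K [3, 3], depth [H, W]):
    #
    #     def depth_to_normal(depth, K):
    #         H, W = depth.shape
    #         gx, gy = torch.meshgrid(
    #             torch.arange(W, device=depth.device) + 0.5,
    #             torch.arange(H, device=depth.device) + 0.5,
    #             indexing="xy",
    #         )
    #         pix = torch.stack([gx, gy, torch.ones_like(gx)], dim=-1)  # [H, W, 3]
    #         rays = pix @ torch.linalg.inv(K).T          # back-projected rays
    #         points = depth.unsqueeze(-1) * rays         # camera-space points
    #         normal = torch.zeros_like(points)
    #         dx = points[2:, 1:-1] - points[:-2, 1:-1]
    #         dy = points[1:-1, 2:] - points[1:-1, :-2]
    #         normal[1:-1, 1:-1] = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1)
    #         return normal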
@torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds = self.parser.camtoworlds[5:-5] + camtoworlds = generate_interpolated_path(camtoworlds, 1) # [N, 3, 4] + camtoworlds = np.concatenate( + [ + camtoworlds, + np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds = torch.from_numpy(camtoworlds).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0][None]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + canvas_all = [] + for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"): + renders, _, expected_depths, median_depths, normals, _ = self.rasterize_splats( + camtoworlds=camtoworlds[i : i + 1], + Ks=K, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB", + require_rade = True, + ) # [1, H, W, 4] + colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3] + + grid_x, grid_y = torch.meshgrid(torch.arange(width)+0.5, torch.arange(height)+0.5, indexing='xy') + points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(1, -1, 3).float().cuda() + rays_d = points @ torch.linalg.inv(K.transpose(2,1)) # 1, M, 3 + points_e = expected_depths.reshape(K.shape[0],-1,1) * rays_d + points_m = median_depths.reshape(K.shape[0],-1,1) * rays_d + points_e = points_e.reshape_as(normals) + points_m = points_m.reshape_as(normals) + normal_map_e = torch.zeros_like(points_e) + dx = points_e[...,2:, 1:-1,:] - points_e[...,:-2, 1:-1,:] + dy = points_e[...,1:-1, 2:,:] - points_e[...,1:-1, :-2,:] + normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) + normal_map_e[...,1:-1, 1:-1, :] = normal_map + normal_map_m = torch.zeros_like(points_m) + dx = points_m[...,2:, 1:-1,:] - points_m[...,:-2, 1:-1,:] + dy = points_m[...,1:-1, 2:,:] - points_m[...,1:-1, :-2,:] + normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) + normal_map_m[...,1:-1, 1:-1, :] = normal_map + # depths = renders[0, ..., 3:4] # [H, W, 1] + # expected_depths = expected_depths[0] + # expected_depths = (expected_depths - expected_depths.min()) / (expected_depths.max() - expected_depths.min()) + + # median_depths = median_depths[0] + # median_depths = (median_depths - median_depths.min()) / (median_depths.max() - median_depths.min()) + normal_map_m = normal_map_m[0] + normal_map_m = (-normal_map_m + 1) / 2 + + normal_map_e = normal_map_e[0] + normal_map_e = (-normal_map_e + 1) / 2 + + normals = normals[0] + normals = (-normals + 1) / 2 + + # write images + # canvas = torch.cat( + # [colors, depths.repeat(1, 1, 3), normals], dim=0 if width > height else 1 + # ) + canvas = torch.cat( + [ + torch.cat([colors, normals], dim=1), + torch.cat([normal_map_m, normal_map_e], dim=1) + ], dim=0) + + canvas = (canvas.cpu().numpy() * 255).astype(np.uint8) + canvas_all.append(canvas) + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=1) + for canvas in canvas_all: + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def recon(self, step: int): + """Entry for reconstrution.""" + print("Running reconstrution...") + cfg = self.cfg + device = self.device + + 
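        # What follows is TSDF fusion of the rasterized geometry: for each
        # training view, the median depth map is kept only where the rendered
        # alpha (and the dataset alpha mask, if present) exceeds 0.5, then
        # integrated together with the rendered color into an Open3D
        # VoxelBlockGrid. The trailing positional arguments 1.0 and 8.0 passed
        # to compute_unique_block_coordinates()/integrate() are the depth scale
        # and the maximum integration depth (depths are already metric in scene
        # units, hence a scale of 1.0). The fused mesh is extracted with
        # extract_triangle_mesh() and written to mesh/recon_{step}.ply.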
trainloader = torch.utils.data.DataLoader( + self.trainset, batch_size=1, shuffle=False, num_workers=1 + ) + + voxel_size = 0.0005 + o3d_device = o3d.core.Device("CPU:0") + vbg = o3d.t.geometry.VoxelBlockGrid(attr_names=('tsdf', 'weight', 'color'), + attr_dtypes=(o3c.float32, + o3c.float32, + o3c.float32), + attr_channels=((1), (1), (3)), + voxel_size=voxel_size, + block_resolution=16, + block_count=50000, + device=o3d_device) + for data in trainloader: + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + height, width = pixels.shape[1:3] + renders, alphas, expected_depths, median_depths, normals, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + require_rade = True, + ) # [1, H, W, 3] + + torch.cuda.empty_cache() + depth = median_depths + depth[alphas<0.5] = 0 + if data["mask"] is not None: + depth[data["mask"]/255<0.5] = 0 + depth = o3d.t.geometry.Image(depth[0,...,0].cpu().numpy()) + depth = depth.to(o3d_device) + color = o3d.t.geometry.Image(torch.clamp(renders, min=0, max=1.0)[0].cpu().numpy()) + color = color.to(o3d_device) + intrinsic = o3d.core.Tensor(Ks[0].cpu().numpy().astype(np.float64)) + extrinsic = o3d.core.Tensor(torch.linalg.inv(camtoworlds[0]).cpu().numpy().astype(np.float64)) + frustum_block_coords = vbg.compute_unique_block_coordinates( + depth, + intrinsic, + extrinsic, + 1.0, 8.0 + ) + vbg.integrate( + frustum_block_coords, + depth, + color, + intrinsic, + extrinsic, + 1.0, 8.0 + ) + + mesh = vbg.extract_triangle_mesh() + mesh.compute_vertex_normals() + mesh_dir = f"{cfg.result_dir}/mesh" + os.makedirs(mesh_dir, exist_ok=True) + o3d.io.write_triangle_mesh(f"{mesh_dir}/recon_{step}.ply",mesh.to_legacy()) + print("done!") + + + + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + render_colors, _, _, _, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=W, + height=H, + sh_degree=self.cfg.sh_degree, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + return render_colors[0].cpu().numpy() + + +def main(cfg: Config): + runner = Runner(cfg) + + if cfg.ckpt is not None: + # run eval only + ckpt = torch.load(cfg.ckpt, map_location=runner.device) + for k in runner.splats.keys(): + runner.splats[k].data = ckpt["splats"][k] + # runner.eval(step=ckpt["step"]) + # runner.render_traj(step=ckpt["step"]) + runner.recon(step=ckpt["step"]) + else: + runner.train() + + if not cfg.disable_viewer: + print("Viewer running... 
Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + cfg = tyro.cli(Config) + cfg.adjust_steps(cfg.steps_scaler) + main(cfg) diff --git a/extract_dtu.sh b/extract_dtu.sh new file mode 100644 index 000000000..4a67163a4 --- /dev/null +++ b/extract_dtu.sh @@ -0,0 +1,12 @@ +for i in 122 118 114 110 106 +do + python examples/simple_trainer_recon.py --eval_steps 100 2000 7000 15000 20000 25000 \ + --disable_viewer --data_factor 2 \ + --data_dir /media/super/data/dataset/dtu/DTU_mask/scan$i/ \ + --result_dir output//scan$i/ \ + --normal_consistency_loss \ + --app_opt \ + --test_every 1000000000 \ + --absgrad \ + --ckpt output/scan$i/ckpts/ckpt_29999.pt +done \ No newline at end of file diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 2fa0c4937..647ebabf0 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -373,6 +373,9 @@ def rasterize_to_pixels( conics: Tensor, # [C, N, 3] or [nnz, 3] colors: Tensor, # [C, N, channels] or [nnz, channels] opacities: Tensor, # [C, N] or [nnz] + ray_ts: Tensor, # [C, N] or [nnz] + ray_planes: Tensor, # [C, N, 2] or [nnz, 2] + normals: Tensor, # [C, N, 3] or [nnz, 3] image_width: int, image_height: int, tile_size: int, @@ -380,9 +383,11 @@ def rasterize_to_pixels( flatten_ids: Tensor, # [n_isects] backgrounds: Optional[Tensor] = None, # [C, channels] masks: Optional[Tensor] = None, # [C, tile_height, tile_width] + Ks: Optional[Tensor] = None, packed: bool = False, absgrad: bool = False, -) -> Tuple[Tensor, Tensor]: + require_geo: bool = False, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Rasterizes Gaussians to pixels. Args: @@ -415,12 +420,20 @@ def rasterize_to_pixels( assert conics.shape == (nnz, 3), conics.shape assert colors.shape[0] == nnz, colors.shape assert opacities.shape == (nnz,), opacities.shape + if require_geo: + assert ray_ts.shape == nnz, means2d.shape + assert ray_planes.shape == (nnz, 2), conics.shape + assert normals.shape == (nnz, 3), colors.shape else: N = means2d.size(1) assert means2d.shape == (C, N, 2), means2d.shape assert conics.shape == (C, N, 3), conics.shape assert colors.shape[:2] == (C, N), colors.shape assert opacities.shape == (C, N), opacities.shape + if require_geo: + assert ray_ts.shape == (C, N), means2d.shape + assert ray_planes.shape == (C, N, 2), conics.shape + assert normals.shape == (C, N, 3), colors.shape if backgrounds is not None: assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape backgrounds = backgrounds.contiguous() @@ -482,25 +495,45 @@ def rasterize_to_pixels( assert ( tile_width * tile_size >= image_width ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}" - - render_colors, render_alphas = _RasterizeToPixels.apply( - means2d.contiguous(), - conics.contiguous(), - colors.contiguous(), - opacities.contiguous(), - backgrounds, - masks, - image_width, - image_height, - tile_size, - isect_offsets.contiguous(), - flatten_ids.contiguous(), - absgrad, - ) + if require_geo: + render_colors, render_alphas, expected_depths, median_depths, expected_normals = _RasterizeToPixels_wDepth.apply( + means2d.contiguous(), + conics.contiguous(), + colors.contiguous(), + opacities.contiguous(), + ray_ts.contiguous(), + ray_planes.contiguous(), + normals.contiguous(), + backgrounds, + masks, + image_width, + image_height, + tile_size, + Ks.contiguous(), + isect_offsets.contiguous(), + flatten_ids.contiguous(), + absgrad, + ) + else: + render_colors, render_alphas = _RasterizeToPixels.apply( + means2d.contiguous(), + conics.contiguous(), + 
colors.contiguous(), + opacities.contiguous(), + backgrounds, + masks, + image_width, + image_height, + tile_size, + isect_offsets.contiguous(), + flatten_ids.contiguous(), + absgrad, + ) + expected_depths, median_depths, expected_normals = None, None, None if padded_channels > 0: render_colors = render_colors[..., :-padded_channels] - return render_colors, render_alphas + return render_colors, render_alphas, expected_depths, median_depths, expected_normals @torch.no_grad() @@ -715,7 +748,7 @@ def forward( calc_compensations: bool, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: # "covars" and {"quats", "scales"} are mutually exclusive - radii, means2d, depths, conics, compensations = _make_lazy_cuda_func( + radii, means2d, depths, conics, compensations, ray_ts, ray_planes, normals = _make_lazy_cuda_func( "fully_fused_projection_fwd" )( means, @@ -741,10 +774,10 @@ def forward( ctx.height = height ctx.eps2d = eps2d - return radii, means2d, depths, conics, compensations + return radii, means2d, depths, conics, compensations, ray_ts, ray_planes, normals @staticmethod - def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations): + def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations, v_ray_ts, v_ray_planes, v_normals): ( means, covars, @@ -780,6 +813,9 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations): v_depths.contiguous(), v_conics.contiguous(), v_compensations, + v_ray_ts.contiguous(), + v_ray_planes.contiguous(), + v_normals.contiguous(), ctx.needs_input_grad[4], # viewmats_requires_grad ) if not ctx.needs_input_grad[0]: @@ -829,7 +865,7 @@ def forward( absgrad: bool, ) -> Tuple[Tensor, Tensor]: render_colors, render_alphas, last_ids = _make_lazy_cuda_func( - "rasterize_to_pixels_fwd" + "rasterize_to_pixels_wo_depth_fwd" )( means2d, conics, @@ -894,7 +930,7 @@ def backward( v_conics, v_colors, v_opacities, - ) = _make_lazy_cuda_func("rasterize_to_pixels_bwd")( + ) = _make_lazy_cuda_func("rasterize_to_pixels_wo_depth_bwd")( means2d, conics, colors, @@ -937,6 +973,171 @@ def backward( None, None, ) + +class _RasterizeToPixels_wDepth(torch.autograd.Function): + """Rasterize gaussians""" + + @staticmethod + def forward( + ctx, + means2d: Tensor, # [C, N, 2] + conics: Tensor, # [C, N, 3] + colors: Tensor, # [C, N, D] + opacities: Tensor, # [C, N] + ray_ts: Tensor, # [C, N] + ray_planes: Tensor, # [C, N, 2] + normals: Tensor, # [C, N, 3] + backgrounds: Tensor, # [C, D], Optional + masks: Tensor, # [C, tile_height, tile_width], Optional + width: int, + height: int, + tile_size: int, + Ks: Tensor, # [C, 3, 3] + isect_offsets: Tensor, # [C, tile_height, tile_width] + flatten_ids: Tensor, # [n_isects] + absgrad: bool, + ) -> Tuple[Tensor, Tensor]: + render_colors, render_alphas, expected_depths, median_depths, expected_normals, median_ids, last_ids = _make_lazy_cuda_func( + "rasterize_to_pixels_w_depth_fwd" + )( + means2d, + conics, + colors, + opacities, + ray_ts, + ray_planes, + normals, + backgrounds, + masks, + width, + height, + tile_size, + Ks, + isect_offsets, + flatten_ids, + ) + + ctx.save_for_backward( + means2d, + conics, + colors, + opacities, + ray_ts, + ray_planes, + normals, + backgrounds, + masks, + isect_offsets, + flatten_ids, + render_alphas, + last_ids, + median_ids, + Ks + ) + ctx.width = width + ctx.height = height + ctx.tile_size = tile_size + ctx.absgrad = absgrad + + # double to float + render_alphas = render_alphas.float() + return render_colors, render_alphas, expected_depths, median_depths, 
expected_normals + + @staticmethod + def backward( + ctx, + v_render_colors: Tensor, # [C, H, W, 3] + v_render_alphas: Tensor, # [C, H, W, 1] + v_render_expected_depths: Tensor, # [C, H, W, 1] + v_render_median_depths: Tensor, # [C, H, W, 1] + v_render_expected_normals: Tensor, # [C, H, W, 1] + ): + ( + means2d, + conics, + colors, + opacities, + ray_ts, + ray_planes, + normals, + backgrounds, + masks, + isect_offsets, + flatten_ids, + render_alphas, + last_ids, + median_ids, + Ks + ) = ctx.saved_tensors + width = ctx.width + height = ctx.height + tile_size = ctx.tile_size + absgrad = ctx.absgrad + # print(means2d.dtype,conics.dtype,colors.dtype,opacities.dtype,ray_ts.dtype,ray_planes.dtype,normals.dtype,v_render_expected_depths.dtype,v_render_expected_normals.dtype) + ( + v_means2d_abs, + v_means2d, + v_conics, + v_colors, + v_opacities, + v_ray_ts, + v_ray_planes, + v_normals + ) = _make_lazy_cuda_func("rasterize_to_pixels_w_depth_bwd")( + means2d, + conics, + colors, + opacities, + ray_ts, + ray_planes, + normals, + backgrounds, + masks, + width, + height, + tile_size, + Ks, + isect_offsets, + flatten_ids, + render_alphas, + last_ids, + median_ids, + v_render_colors.contiguous(), + v_render_alphas.contiguous(), + v_render_expected_depths.contiguous(), + v_render_median_depths.contiguous(), + v_render_expected_normals.contiguous(), + absgrad, + ) + + if absgrad: + means2d.absgrad = v_means2d_abs + + if ctx.needs_input_grad[7]: + v_backgrounds = (v_render_colors * (1.0 - render_alphas).float()).sum( + dim=(1, 2) + ) + else: + v_backgrounds = None + + return ( + v_means2d, + v_conics, + v_colors, + v_opacities, + v_ray_ts, + v_ray_planes, + v_normals, + v_backgrounds, + None, + None, + None, + None, + None, + None, + None, + None + ) class _FullyFusedProjectionPacked(torch.autograd.Function): @@ -969,6 +1170,9 @@ def forward( depths, conics, compensations, + ray_ts, + ray_planes, + normals ) = _make_lazy_cuda_func("fully_fused_projection_packed_fwd")( means, covars, # optional @@ -1003,7 +1207,7 @@ def forward( ctx.eps2d = eps2d ctx.sparse_grad = sparse_grad - return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations + return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations, ray_ts, ray_planes, normals @staticmethod def backward( @@ -1015,6 +1219,9 @@ def backward( v_depths, v_conics, v_compensations, + v_ray_ts, + v_ray_planes, + v_normals ): ( camera_ids, @@ -1055,6 +1262,9 @@ def backward( v_depths.contiguous(), v_conics.contiguous(), v_compensations, + v_ray_ts.contiguous(), + v_ray_planes.contiguous(), + v_normals.contiguous(), ctx.needs_input_grad[4], # viewmats_requires_grad sparse_grad, ) diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index faf28bad0..01fbeb41d 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -47,7 +47,9 @@ persp_proj_bwd_tensor(const torch::Tensor &means, // [C, N, 3] const torch::Tensor &Ks, // [C, 3, 3] const uint32_t width, const uint32_t height, const torch::Tensor &v_means2d, // [C, N, 2] - const torch::Tensor &v_covars2d // [C, N, 2, 2] + const torch::Tensor &v_covars2d, // [C, N, 2, 2] + const torch::Tensor &v_ray_planes, // [C, N, 2] + const torch::Tensor &v_normals // [C, N, 3] ); std::tuple @@ -65,7 +67,7 @@ world_to_cam_bwd_tensor(const torch::Tensor &means, // [N, 3] const bool means_requires_grad, const bool covars_requires_grad, const bool viewmats_requires_grad); -std::tuple +std::tuple fully_fused_projection_fwd_tensor( const torch::Tensor &means, // 
[N, 3] const at::optional &covars, // [N, 6] optional @@ -96,6 +98,9 @@ fully_fused_projection_bwd_tensor( const torch::Tensor &v_depths, // [C, N] const torch::Tensor &v_conics, // [C, N, 3] const at::optional &v_compensations, // [C, N] optional + const torch::Tensor &v_ray_ts, // [C, N] + const torch::Tensor &v_ray_planes, // [C, N, 2] + const torch::Tensor &v_normals, // [C, N, 3] const bool viewmats_requires_grad); std::tuple @@ -112,7 +117,8 @@ torch::Tensor isect_offset_encode_tensor(const torch::Tensor &isect_ids, // [n_i const uint32_t C, const uint32_t tile_width, const uint32_t tile_height); -std::tuple rasterize_to_pixels_fwd_tensor( +std::tuple +rasterize_to_pixels_wo_depth_fwd_tensor( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] const torch::Tensor &conics, // [C, N, 3] @@ -127,17 +133,68 @@ std::tuple rasterize_to_pixels_fwd_ const torch::Tensor &flatten_ids // [n_isects] ); -std::tuple -rasterize_to_pixels_bwd_tensor( +std::tuple +rasterize_to_pixels_w_depth_fwd_tensor( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &colors, // [C, N, 3] + const torch::Tensor &colors, // [C, N, D] const torch::Tensor &opacities, // [N] - const at::optional &backgrounds, // [C, 3] + const torch::Tensor &ray_ts, // [C, N] + const torch::Tensor &ray_planes, // [C, N, 2] + const torch::Tensor &normals, // [C, N, 3] + const at::optional &backgrounds, // [C, D] const at::optional &mask, // [C, tile_height, tile_width] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const torch::Tensor &Ks, // [C, 3, 3] + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +); + +std::tuple +rasterize_to_pixels_w_depth_bwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &ray_ts, // [C, N] or [nnz] + const torch::Tensor &ray_planes, // [C, N, 2] or [nnz, 2] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3] + const at::optional &backgrounds, // [C, 3] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const torch::Tensor &Ks, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // forward outputs + const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &last_ids, // [C, image_height, image_width] + const torch::Tensor &median_ids, // [C, image_height, image_width] + // gradients of outputs + const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_expected_depths, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_median_depths, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_expected_normals, // [C, image_height, image_width, 3] + // options + bool absgrad); + +std::tuple +rasterize_to_pixels_wo_depth_bwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + 
const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] + const torch::Tensor &opacities, // [C, N] or [nnz] + const at::optional &backgrounds, // [C, 3] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids, // [n_isects] @@ -181,7 +238,7 @@ compute_sh_bwd_tensor(const uint32_t K, const uint32_t degrees_to_use, * Packed Version ****************************************************************************************/ std::tuple + torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> fully_fused_projection_packed_fwd_tensor( const torch::Tensor &means, // [N, 3] const at::optional &covars, // [N, 6] @@ -213,6 +270,9 @@ fully_fused_projection_packed_bwd_tensor( const torch::Tensor &v_depths, // [nnz] const torch::Tensor &v_conics, // [nnz, 3] const at::optional &v_compensations, // [nnz] optional + const torch::Tensor &v_ray_ts, // [nnz] + const torch::Tensor &v_ray_planes, // [nnz, 2] + const torch::Tensor &v_normals, // [nnz, 3] const bool viewmats_requires_grad, const bool sparse_grad); std::tuple diff --git a/gsplat/cuda/csrc/ext.cpp b/gsplat/cuda/csrc/ext.cpp index 9dec63597..a54e77901 100644 --- a/gsplat/cuda/csrc/ext.cpp +++ b/gsplat/cuda/csrc/ext.cpp @@ -20,8 +20,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("isect_tiles", &isect_tiles_tensor); m.def("isect_offset_encode", &isect_offset_encode_tensor); - m.def("rasterize_to_pixels_fwd", &rasterize_to_pixels_fwd_tensor); - m.def("rasterize_to_pixels_bwd", &rasterize_to_pixels_bwd_tensor); + m.def("rasterize_to_pixels_w_depth_fwd", &rasterize_to_pixels_w_depth_fwd_tensor); + m.def("rasterize_to_pixels_wo_depth_fwd", &rasterize_to_pixels_wo_depth_fwd_tensor); + m.def("rasterize_to_pixels_w_depth_bwd", &rasterize_to_pixels_w_depth_bwd_tensor); + m.def("rasterize_to_pixels_wo_depth_bwd", &rasterize_to_pixels_wo_depth_bwd_tensor); m.def("rasterize_to_indices_in_range", &rasterize_to_indices_in_range_tensor); diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index 0f8e16ba5..bd32aa24a 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -35,6 +35,9 @@ __global__ void fully_fused_projection_bwd_kernel( const T *__restrict__ v_depths, // [C, N] const T *__restrict__ v_conics, // [C, N, 3] const T *__restrict__ v_compensations, // [C, N] optional + const T *__restrict__ v_ray_ts, // [C, N] + const T *__restrict__ v_ray_planes, // [C, N, 2] + const T *__restrict__ v_normals, // [C, N, 3] // grad inputs T *__restrict__ v_means, // [N, 3] T *__restrict__ v_covars, // [N, 6] optional @@ -60,6 +63,9 @@ __global__ void fully_fused_projection_bwd_kernel( v_means2d += idx * 2; v_depths += idx; v_conics += idx * 3; + v_ray_ts += idx; + v_ray_planes += idx * 2; + v_normals += idx * 3; // vjp: compute the inverse of the 2d covariance mat2 covar2d_inv = mat2(conics[0], conics[1], conics[1], conics[2]); @@ -67,6 +73,9 @@ __global__ void fully_fused_projection_bwd_kernel( mat2(v_conics[0], v_conics[1] * .5f, v_conics[1] * .5f, v_conics[2]); mat2 v_covar2d(0.f); inverse_vjp(covar2d_inv, v_covar2d_inv, v_covar2d); + T v_ray_t = v_ray_ts[0]; + vec2 v_ray_plane = {v_ray_planes[0], v_ray_planes[1]}; + vec3 v_normal = {v_normals[0], v_normals[1], v_normals[2]}; if (v_compensations 
!= nullptr) { // vjp: compensation term @@ -107,10 +116,11 @@ __global__ void fully_fused_projection_bwd_kernel( mat3 v_covar_c(0.f); vec3 v_mean_c(0.f); persp_proj_vjp(mean_c, covar_c, fx, fy, cx, cy, image_width, image_height, - v_covar2d, glm::make_vec2(v_means2d), v_mean_c, v_covar_c); + v_covar2d, glm::make_vec2(v_means2d), v_ray_plane, v_normal, v_mean_c, v_covar_c); // add contribution from v_depths v_mean_c.z += v_depths[0]; + v_mean_c += v_ray_ts[0] * glm::normalize(mean_c); // vjp: transform Gaussian covariance to camera space vec3 v_mean(0.f); @@ -203,6 +213,9 @@ fully_fused_projection_bwd_tensor( const torch::Tensor &v_depths, // [C, N] const torch::Tensor &v_conics, // [C, N, 3] const at::optional &v_compensations, // [C, N] optional + const torch::Tensor &v_ray_ts, // [C, N] + const torch::Tensor &v_ray_planes, // [C, N, 2] + const torch::Tensor &v_normals, // [C, N, 3] const bool viewmats_requires_grad) { DEVICE_GUARD(means); CHECK_INPUT(means); @@ -258,6 +271,7 @@ fully_fused_projection_bwd_tensor( v_conics.data_ptr(), v_compensations.has_value() ? v_compensations.value().data_ptr() : nullptr, + v_ray_ts.data_ptr(), v_ray_planes.data_ptr(), v_normals.data_ptr(), v_means.data_ptr(), covars.has_value() ? v_covars.data_ptr() : nullptr, covars.has_value() ? nullptr : v_quats.data_ptr(), diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index 420f78fdf..b378bd2cf 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -31,7 +31,10 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, T *__restrict__ means2d, // [C, N, 2] T *__restrict__ depths, // [C, N] T *__restrict__ conics, // [C, N, 3] - T *__restrict__ compensations // [C, N] optional + T *__restrict__ compensations, // [C, N] optional + T *__restrict__ ray_ts, // [C, N] optional + T *__restrict__ ray_planes, // [C, N] optional + T *__restrict__ normals // [C, N] optional ) { // parallelize over C * N. 
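    // Outputs added by this patch: in addition to radii / means2d / depths /
    // conics / compensations, each visible Gaussian now also stores
    //   ray_ts[idx]     - glm::length(mean_c), the camera-to-Gaussian distance
    //                     in camera space,
    //   ray_planes[...] - the two ray-plane coefficients produced by persp_proj(),
    //                     consumed by the depth-aware rasterizer
    //                     (rasterize_to_pixels_w_depth_fwd), and
    //   normals[...]    - the camera-space normal produced by persp_proj().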
uint32_t idx = cg::this_grid().thread_rank(); @@ -61,6 +64,8 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, return; } + T ray_t = glm::length(mean_c); + // transform Gaussian covariance to camera space mat3 covar; if (covars != nullptr) { @@ -82,8 +87,10 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, // perspective projection mat2 covar2d; vec2 mean2d; + vec2 ray_plane; + vec3 normal; persp_proj(mean_c, covar_c, Ks[0], Ks[4], Ks[2], Ks[5], image_width, - image_height, covar2d, mean2d); + image_height, covar2d, mean2d, ray_plane, normal); T compensation; T det = add_blur(eps2d, covar2d, compensation); @@ -126,9 +133,15 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, if (compensations != nullptr) { compensations[idx] = compensation; } + ray_ts[idx] = ray_t; + ray_planes[idx * 2] = ray_plane.x; + ray_planes[idx * 2 + 1] = ray_plane.y; + normals[idx * 3] = normal.x; + normals[idx * 3 + 1] = normal.y; + normals[idx * 3 + 2] = normal.z; } -std::tuple +std::tuple fully_fused_projection_fwd_tensor( const torch::Tensor &means, // [N, 3] const at::optional &covars, // [N, 6] optional @@ -164,6 +177,9 @@ fully_fused_projection_fwd_tensor( // we dont want NaN to appear in this tensor, so we zero intialize it compensations = torch::zeros({C, N}, means.options()); } + torch::Tensor ray_ts = torch::empty({C, N}, means.options()); + torch::Tensor ray_planes = torch::empty({C, N, 2}, means.options()); + torch::Tensor normals = torch::empty({C, N, 3}, means.options()); if (C && N) { fully_fused_projection_fwd_kernel <<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>( @@ -175,7 +191,8 @@ fully_fused_projection_fwd_tensor( image_height, eps2d, near_plane, far_plane, radius_clip, radii.data_ptr(), means2d.data_ptr(), depths.data_ptr(), conics.data_ptr(), - calc_compensations ? compensations.data_ptr() : nullptr); + calc_compensations ? compensations.data_ptr() : nullptr, + ray_ts.data_ptr(),ray_planes.data_ptr(),normals.data_ptr()); } - return std::make_tuple(radii, means2d, depths, conics, compensations); + return std::make_tuple(radii, means2d, depths, conics, compensations, ray_ts, ray_planes, normals); } diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index 476390814..88a2f2af2 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -36,6 +36,9 @@ __global__ void fully_fused_projection_packed_bwd_kernel( const T *__restrict__ v_depths, // [nnz] const T *__restrict__ v_conics, // [nnz, 3] const T *__restrict__ v_compensations, // [nnz] optional + const T *__restrict__ v_ray_ts, // [nnz] + const T *__restrict__ v_ray_planes, // [nnz, 2] + const T *__restrict__ v_normals, // [nnz, 3] const bool sparse_grad, // whether the outputs are in COO format [nnz, ...] 
// grad inputs T *__restrict__ v_means, // [N, 3] or [nnz, 3] @@ -62,6 +65,9 @@ __global__ void fully_fused_projection_packed_bwd_kernel( v_means2d += idx * 2; v_depths += idx; v_conics += idx * 3; + v_ray_ts += idx; + v_ray_planes += idx * 2; + v_normals += idx * 3; // vjp: compute the inverse of the 2d covariance mat2 covar2d_inv = mat2(conics[0], conics[1], conics[1], conics[2]); @@ -69,6 +75,9 @@ __global__ void fully_fused_projection_packed_bwd_kernel( mat2(v_conics[0], v_conics[1] * .5f, v_conics[1] * .5f, v_conics[2]); mat2 v_covar2d(0.f); inverse_vjp(covar2d_inv, v_covar2d_inv, v_covar2d); + T v_ray_t = v_ray_ts[0]; + vec2 v_ray_plane = {v_ray_planes[0], v_ray_planes[1]}; + vec3 v_normal = {v_normals[0], v_normals[1], v_normals[2]}; if (v_compensations != nullptr) { // vjp: compensation term @@ -109,10 +118,12 @@ __global__ void fully_fused_projection_packed_bwd_kernel( mat3 v_covar_c(0.f); vec3 v_mean_c(0.f); persp_proj_vjp(mean_c, covar_c, fx, fy, cx, cy, image_width, image_height, - v_covar2d, glm::make_vec2(v_means2d), v_mean_c, v_covar_c); + v_covar2d, glm::make_vec2(v_means2d), v_ray_plane, v_normal, + v_mean_c, v_covar_c); // add contribution from v_depths v_mean_c.z += v_depths[0]; + v_mean_c += v_ray_ts[0] * glm::normalize(mean_c); // vjp: transform Gaussian covariance to camera space vec3 v_mean(0.f); @@ -242,6 +253,9 @@ fully_fused_projection_packed_bwd_tensor( const torch::Tensor &v_depths, // [nnz] const torch::Tensor &v_conics, // [nnz, 3] const at::optional &v_compensations, // [nnz] optional + const torch::Tensor &v_ray_ts, // [nnz] + const torch::Tensor &v_ray_planes, // [nnz, 2] + const torch::Tensor &v_normals, // [nnz, 3] const bool viewmats_requires_grad, const bool sparse_grad) { DEVICE_GUARD(means); CHECK_INPUT(means); @@ -260,6 +274,9 @@ fully_fused_projection_packed_bwd_tensor( CHECK_INPUT(v_means2d); CHECK_INPUT(v_depths); CHECK_INPUT(v_conics); + CHECK_INPUT(v_ray_ts); + CHECK_INPUT(v_ray_planes); + CHECK_INPUT(v_normals); if (compensations.has_value()) { CHECK_INPUT(compensations.value()); } @@ -312,6 +329,9 @@ fully_fused_projection_packed_bwd_tensor( v_conics.data_ptr(), v_compensations.has_value() ? v_compensations.value().data_ptr() : nullptr, + v_ray_ts.data_ptr(), + v_ray_planes.data_ptr(), + v_normals.data_ptr(), sparse_grad, v_means.data_ptr(), covars.has_value() ? v_covars.data_ptr() : nullptr, covars.has_value() ? 
nullptr : v_quats.data_ptr(), diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 7f7082f6e..91debabc2 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -37,7 +37,10 @@ __global__ void fully_fused_projection_packed_fwd_kernel( T *__restrict__ means2d, // [nnz, 2] T *__restrict__ depths, // [nnz] T *__restrict__ conics, // [nnz, 3] - T *__restrict__ compensations // [nnz] optional + T *__restrict__ compensations, // [nnz] optional + T *__restrict__ ray_ts, // [nnz] + T *__restrict__ ray_planes, // [nnz, 2] + T *__restrict__ normals // [nnz, 2] ) { int32_t blocks_per_row = gridDim.x; @@ -52,6 +55,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( // check if points are with camera near and far plane vec3 mean_c; mat3 R; + T ray_t; if (valid) { // shift pointers to the current camera and gaussian means += col_idx * 3; @@ -66,6 +70,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( // transform Gaussian center to camera space pos_world_to_cam(R, t, glm::make_vec3(means), mean_c); + ray_t = glm::length(mean_c); if (mean_c.z < near_plane || mean_c.z > far_plane) { valid = false; } @@ -77,6 +82,8 @@ __global__ void fully_fused_projection_packed_fwd_kernel( mat2 covar2d_inv; T compensation; T det; + vec2 ray_plane; + vec3 normal; if (valid) { // transform Gaussian covariance to camera space mat3 covar; @@ -99,7 +106,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( // perspective projection Ks += row_idx * 9; persp_proj(mean_c, covar_c, Ks[0], Ks[4], Ks[2], Ks[5], image_width, - image_height, covar2d, mean2d); + image_height, covar2d, mean2d, ray_plane, normal); det = add_blur(eps2d, covar2d, compensation); if (det <= 0.f) { @@ -169,6 +176,12 @@ __global__ void fully_fused_projection_packed_fwd_kernel( if (compensations != nullptr) { compensations[thread_data] = compensation; } + ray_ts[thread_data] = ray_t; + ray_planes[thread_data * 2] = ray_plane.x; + ray_planes[thread_data * 2 + 1] = ray_plane.y; + normals[thread_data * 3] = normal.x; + normals[thread_data * 3 + 1] = normal.y; + normals[thread_data * 3 + 2] = normal.z; } // lane 0 of the first block in each row writes the indptr if (threadIdx.x == 0 && block_col_idx == 0) { @@ -183,7 +196,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( } std::tuple + torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> fully_fused_projection_packed_fwd_tensor( const torch::Tensor &means, // [N, 3] const at::optional &covars, // [N, 6] @@ -232,7 +245,7 @@ fully_fused_projection_packed_fwd_tensor( viewmats.data_ptr(), Ks.data_ptr(), image_width, image_height, eps2d, near_plane, far_plane, radius_clip, nullptr, block_cnts.data_ptr(), nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr); + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); block_accum = torch::cumsum(block_cnts, 0, torch::kInt32); nnz = block_accum[-1].item(); } else { @@ -247,6 +260,9 @@ fully_fused_projection_packed_fwd_tensor( torch::Tensor means2d = torch::empty({nnz, 2}, means.options()); torch::Tensor depths = torch::empty({nnz}, means.options()); torch::Tensor conics = torch::empty({nnz, 3}, means.options()); + torch::Tensor ray_ts = torch::empty({nnz}, means.options()); + torch::Tensor ray_planes = torch::empty({nnz, 2}, means.options()); + torch::Tensor normals = torch::empty({nnz, 2}, means.options()); torch::Tensor 
compensations; if (calc_compensations) { // we dont want NaN to appear in this tensor, so we zero intialize it @@ -265,11 +281,14 @@ fully_fused_projection_packed_fwd_tensor( gaussian_ids.data_ptr(), radii.data_ptr(), means2d.data_ptr(), depths.data_ptr(), conics.data_ptr(), - calc_compensations ? compensations.data_ptr() : nullptr); + calc_compensations ? compensations.data_ptr() : nullptr, + ray_ts.data_ptr(), + ray_planes.data_ptr(), + normals.data_ptr()); } else { indptr.fill_(0); } return std::make_tuple(indptr, camera_ids, gaussian_ids, radii, means2d, depths, - conics, compensations); + conics, compensations, ray_ts, ray_planes, normals); } diff --git a/gsplat/cuda/csrc/persp_proj_bwd.cu b/gsplat/cuda/csrc/persp_proj_bwd.cu index 00341f771..fe8e5f340 100644 --- a/gsplat/cuda/csrc/persp_proj_bwd.cu +++ b/gsplat/cuda/csrc/persp_proj_bwd.cu @@ -24,6 +24,8 @@ persp_proj_bwd_kernel(const uint32_t C, const uint32_t N, const uint32_t width, const uint32_t height, const T *__restrict__ v_means2d, // [C, N, 2] const T *__restrict__ v_covars2d, // [C, N, 2, 2] + const T *__restrict__ v_ray_planes, // [C, N, 2] + const T *__restrict__ v_normals, // [C, N, 3] T *__restrict__ v_means, // [C, N, 3] T *__restrict__ v_covars // [C, N, 3, 3] ) { @@ -47,6 +49,8 @@ persp_proj_bwd_kernel(const uint32_t C, const uint32_t N, Ks += cid * 9; v_means2d += idx * 2; v_covars2d += idx * 4; + v_ray_planes += idx * 2; + v_normals += idx * 3; OpT fx = Ks[0], cx = Ks[2], fy = Ks[4], cy = Ks[5]; mat3 v_covar(0.f); @@ -55,9 +59,11 @@ persp_proj_bwd_kernel(const uint32_t C, const uint32_t N, const mat3 covar = glm::make_mat3(covars); const vec2 v_mean2d = glm::make_vec2(v_means2d); const mat2 v_covar2d = glm::make_mat2(v_covars2d); + const vec2 v_ray_plane = glm::make_vec2(v_ray_planes); + const vec3 v_normal = glm::make_vec3(v_normals); persp_proj_vjp(mean, covar, fx, fy, cx, cy, width, height, glm::transpose(v_covar2d), - v_mean2d, v_mean, v_covar); + v_mean2d, v_ray_plane, v_normal, v_mean, v_covar); // write to outputs: glm is column-major but we want row-major PRAGMA_UNROLL @@ -80,7 +86,9 @@ persp_proj_bwd_tensor(const torch::Tensor &means, // [C, N, 3] const torch::Tensor &Ks, // [C, 3, 3] const uint32_t width, const uint32_t height, const torch::Tensor &v_means2d, // [C, N, 2] - const torch::Tensor &v_covars2d // [C, N, 2, 2] + const torch::Tensor &v_covars2d, // [C, N, 2, 2] + const torch::Tensor &v_ray_planes, // [C, N, 2] + const torch::Tensor &v_normals // [C, N, 3] ) { DEVICE_GUARD(means); CHECK_INPUT(means); @@ -102,6 +110,7 @@ persp_proj_bwd_tensor(const torch::Tensor &means, // [C, N, 3] C, N, means.data_ptr(), covars.data_ptr(), Ks.data_ptr(), width, height, v_means2d.data_ptr(), v_covars2d.data_ptr(), v_means.data_ptr(), + v_ray_planes.data_ptr(), v_normals.data_ptr(), v_covars.data_ptr()); }); } diff --git a/gsplat/cuda/csrc/persp_proj_fwd.cu b/gsplat/cuda/csrc/persp_proj_fwd.cu index 2575efd2f..9e0676878 100644 --- a/gsplat/cuda/csrc/persp_proj_fwd.cu +++ b/gsplat/cuda/csrc/persp_proj_fwd.cu @@ -21,7 +21,9 @@ __global__ void persp_proj_fwd_kernel(const uint32_t C, const uint32_t N, const T *__restrict__ Ks, // [C, 3, 3] const uint32_t width, const uint32_t height, T *__restrict__ means2d, // [C, N, 2] - T *__restrict__ covars2d // [C, N, 2, 2] + T *__restrict__ covars2d, // [C, N, 2, 2] + T *__restrict__ ray_planes, // [C, N, 2] + T *__restrict__ normals // [C, N, 3] ) { // For now we'll upcast float16 and bfloat16 to float32 using OpT = typename OpType::type; @@ -40,13 +42,17 @@ __global__ void 
persp_proj_fwd_kernel(const uint32_t C, const uint32_t N, Ks += cid * 9; means2d += idx * 2; covars2d += idx * 4; + ray_planes += idx * 2; + normals += idx * 3; OpT fx = Ks[0], cx = Ks[2], fy = Ks[4], cy = Ks[5]; mat2 covar2d(0.f); vec2 mean2d(0.f); + vec2 ray_plane(0.f); + vec3 normal(0.f); const vec3 mean = glm::make_vec3(means); const mat3 covar = glm::make_mat3(covars); - persp_proj(mean, covar, fx, fy, cx, cy, width, height, covar2d, mean2d); + persp_proj(mean, covar, fx, fy, cx, cy, width, height, covar2d, mean2d, ray_plane, normal); // write to outputs: glm is column-major but we want row-major PRAGMA_UNROLL @@ -60,6 +66,14 @@ __global__ void persp_proj_fwd_kernel(const uint32_t C, const uint32_t N, for (uint32_t i = 0; i < 2; i++) { means2d[i] = T(mean2d[i]); } + PRAGMA_UNROLL + for (uint32_t i = 0; i < 2; i++) { + ray_planes[i] = T(ray_plane[i]); + } + PRAGMA_UNROLL + for (uint32_t i = 0; i < 2; i++) { + normals[i] = T(normal[i]); + } } std::tuple @@ -77,6 +91,8 @@ persp_proj_fwd_tensor(const torch::Tensor &means, // [C, N, 3] torch::Tensor means2d = torch::empty({C, N, 2}, means.options()); torch::Tensor covars2d = torch::empty({C, N, 2, 2}, covars.options()); + torch::Tensor ray_plane = torch::empty({C, N, 2}, covars.options()); + torch::Tensor normal = torch::empty({C, N, 3}, covars.options()); if (C && N) { at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); @@ -87,7 +103,8 @@ persp_proj_fwd_tensor(const torch::Tensor &means, // [C, N, 3] <<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>( C, N, means.data_ptr(), covars.data_ptr(), Ks.data_ptr(), width, height, - means2d.data_ptr(), covars2d.data_ptr()); + means2d.data_ptr(), covars2d.data_ptr(), + ray_plane.data_ptr(), normal.data_ptr()); }); } return std::make_tuple(means2d, covars2d); diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_bwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_bwd.cu index 0ec2870e8..4e4907089 100644 --- a/gsplat/cuda/csrc/rasterize_to_pixels_bwd.cu +++ b/gsplat/cuda/csrc/rasterize_to_pixels_bwd.cu @@ -11,7 +11,7 @@ namespace cg = cooperative_groups; * Rasterization to Pixels Backward Pass ****************************************************************************/ -template +template __global__ void rasterize_to_pixels_bwd_kernel( const uint32_t C, const uint32_t N, const uint32_t n_isects, const bool packed, // fwd inputs @@ -19,25 +19,36 @@ __global__ void rasterize_to_pixels_bwd_kernel( const vec3 *__restrict__ conics, // [C, N, 3] or [nnz, 3] const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] const S *__restrict__ opacities, // [C, N] or [nnz] + const S *__restrict__ ray_ts, // [C, N] or [nnz] + const vec2 *__restrict__ ray_planes, // [C, N, 2] or [nnz, 2] + const vec3 *__restrict__ normals, // [C, N, 3] or [nnz, 3] const S *__restrict__ backgrounds, // [C, COLOR_DIM] or [nnz, COLOR_DIM] const bool *__restrict__ masks, // [C, tile_height, tile_width] const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, const uint32_t tile_width, const uint32_t tile_height, + const S *__restrict__ Ks, const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] const int32_t *__restrict__ flatten_ids, // [n_isects] // fwd outputs const S *__restrict__ render_alphas, // [C, image_height, image_width, 1] const int32_t *__restrict__ last_ids, // [C, image_height, image_width] + const int32_t *__restrict__ median_ids, // [C, image_height, image_width] // grad outputs const S *__restrict__ v_render_colors, // [C, image_height, 
image_width, // COLOR_DIM] const S *__restrict__ v_render_alphas, // [C, image_height, image_width, 1] + const S *__restrict__ v_render_expected_depths, // [C, image_height, image_width, 1] + const S *__restrict__ v_render_median_depths, // [C, image_height, image_width, 1] + const vec3 *__restrict__ v_render_expected_normals, // [C, image_height, image_width, 3] // grad inputs vec2 *__restrict__ v_means2d_abs, // [C, N, 2] or [nnz, 2] vec2 *__restrict__ v_means2d, // [C, N, 2] or [nnz, 2] vec3 *__restrict__ v_conics, // [C, N, 3] or [nnz, 3] S *__restrict__ v_colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] - S *__restrict__ v_opacities // [C, N] or [nnz] + S *__restrict__ v_opacities, // [C, N] or [nnz] + S *__restrict__ v_ray_ts, // [C, N] or [nnz] + S *__restrict__ v_ray_planes, // [C, N, 2] or [nnz, 2] + S *__restrict__ v_normals // [C, N, 3] or [nnz, 3] ) { auto block = cg::this_thread_block(); uint32_t camera_id = block.group_index().x; @@ -65,6 +76,15 @@ __global__ void rasterize_to_pixels_bwd_kernel( const S px = (S)j + 0.5f; const S py = (S)i + 0.5f; + + S ln; + if constexpr (GEO) + { + Ks += camera_id * 9; + S fx = Ks[0], cx = Ks[2], fy = Ks[4], cy = Ks[5]; + vec3 pixnf = {(px - cx) / fx, (py - cy) / fy, 1}; + ln = glm::length(pixnf); + } // clamp this value to the last pixel const int32_t pix_id = min(i * image_width + j, image_width * image_height - 1); @@ -90,12 +110,26 @@ __global__ void rasterize_to_pixels_bwd_kernel( vec3 *conic_batch = reinterpret_cast *>(&xy_opacity_batch[block_size]); // [block_size] S *rgbs_batch = (S *)&conic_batch[block_size]; // [block_size * COLOR_DIM] + S *ray_t_batch; + vec2 *ray_plane_batch; + vec3 *normal_batch; + if constexpr (GEO) + { + ray_t_batch = + reinterpret_cast(&rgbs_batch[block_size * COLOR_DIM]); // [block_size] + ray_plane_batch = + reinterpret_cast *>(&ray_t_batch[block_size]); // [block_size] + normal_batch = + reinterpret_cast *>(&ray_plane_batch[block_size]); // [block_size] + } // this is the T AFTER the last gaussian in this pixel S T_final = 1.0f - render_alphas[pix_id]; S T = T_final; // the contribution from gaussians behind the current one - S buffer[COLOR_DIM] = {0.f}; + S color_buffer[COLOR_DIM] = {0.f}; + S t_buffer = 0.f; + vec3 normal_buffer = {0.f, 0.f, 0.f}; // index of last gaussian to contribute to this pixel const int32_t bin_final = inside ? 
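Both rasterization kernels precompute a per-pixel ray scale ln, the length of the unnormalized ray direction [(px-cx)/fx, (py-cy)/fy, 1]. It converts between ray distance and planar depth (z = t / ln), which is why the depth gradients are divided by ln before they reach the per-Gaussian ray terms. A small NumPy sketch of that conversion (all values are placeholders):

import numpy as np

def pixel_ray_scale(px, py, fx, fy, cx, cy):
    # length of the unnormalized camera ray through the pixel center
    return np.linalg.norm([(px - cx) / fx, (py - cy) / fy, 1.0])

fx = fy = 500.0
cx, cy = 320.0, 240.0
ln = pixel_ray_scale(100.5, 50.5, fx, fy, cx, cy)
t = 3.0               # distance along the pixel ray
z = t / ln            # corresponding depth along the optical axis
v_depth = 1.0         # upstream gradient on the rendered depth
v_t = v_depth / ln    # gradient routed to the ray distance, as in the kernels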
last_ids[pix_id] : 0; @@ -106,6 +140,18 @@ __global__ void rasterize_to_pixels_bwd_kernel( v_render_c[k] = v_render_colors[pix_id * COLOR_DIM + k]; } const S v_render_a = v_render_alphas[pix_id]; + S v_render_et, v_render_mt; + uint32_t median_idx; + vec3 v_render_en; + if constexpr (GEO) + { + S v_render_ed = v_render_expected_depths[pix_id]; + v_render_et = v_render_ed / ln; + S v_render_md = v_render_median_depths[pix_id]; + v_render_mt = v_render_md / ln; + v_render_en = v_render_expected_normals[pix_id]; + median_idx = median_ids[pix_id]; + } // collect and process batches of gaussians // each thread loads one gaussian at a time before rasterizing @@ -135,6 +181,12 @@ __global__ void rasterize_to_pixels_bwd_kernel( for (uint32_t k = 0; k < COLOR_DIM; ++k) { rgbs_batch[tr * COLOR_DIM + k] = colors[g * COLOR_DIM + k]; } + if constexpr (GEO) + { + ray_t_batch[tr] = ray_ts[g]; + ray_plane_batch[tr] = ray_planes[g]; + normal_batch[tr] = normals[g]; + } } // wait for other threads to collect the gaussians in batch block.sync(); @@ -174,6 +226,9 @@ __global__ void rasterize_to_pixels_bwd_kernel( vec3 v_conic_local = {0.f, 0.f, 0.f}; vec2 v_xy_local = {0.f, 0.f}; vec2 v_xy_abs_local = {0.f, 0.f}; + S v_ray_t_local = 0.f; + vec2 v_ray_plane_local = {0.f, 0.f}; + vec3 v_normal_local = {0.f, 0.f, 0.f}; S v_opacity_local = 0.f; // initialize everything to 0, only set if the lane is valid if (valid) { @@ -189,7 +244,7 @@ __global__ void rasterize_to_pixels_bwd_kernel( // contribution from this pixel S v_alpha = 0.f; for (uint32_t k = 0; k < COLOR_DIM; ++k) { - v_alpha += (rgbs_batch[t * COLOR_DIM + k] * T - buffer[k] * ra) * + v_alpha += (rgbs_batch[t * COLOR_DIM + k] * T - color_buffer[k] * ra) * v_render_c[k]; } @@ -204,6 +259,25 @@ __global__ void rasterize_to_pixels_bwd_kernel( v_alpha += -T_final * ra * accum; } + if constexpr (GEO) + { + v_normal_local = fac * v_render_en; + v_alpha += glm::dot(normal_batch[t] * T - normal_buffer * ra, v_render_en); + normal_buffer += normal_batch[t] * fac; + + v_ray_t_local = fac * v_render_et; + if (batch_end - t == median_idx) + { + v_ray_t_local += v_render_mt; + } + v_ray_plane_local = v_ray_t_local * delta; + const S ray_t = ray_t_batch[t]; + const vec2 ray_plane = ray_plane_batch[t]; + S t_opt = ray_t + glm::dot(delta, ray_plane); + v_alpha += (t_opt * T - t_buffer * ra) * v_render_et; + t_buffer += t_opt * fac; + } + if (opac * vis <= 0.999f) { const S v_sigma = -opac * vis * v_alpha; v_conic_local = {0.5f * v_sigma * delta.x * delta.x, @@ -219,12 +293,18 @@ __global__ void rasterize_to_pixels_bwd_kernel( PRAGMA_UNROLL for (uint32_t k = 0; k < COLOR_DIM; ++k) { - buffer[k] += rgbs_batch[t * COLOR_DIM + k] * fac; + color_buffer[k] += rgbs_batch[t * COLOR_DIM + k] * fac; } } warpSum(v_rgb_local, warp); warpSum(v_conic_local, warp); warpSum(v_xy_local, warp); + if constexpr (GEO) + { + warpSum(v_ray_t_local, warp); + warpSum(v_ray_plane_local, warp); + warpSum(v_normal_local, warp); + } if (v_means2d_abs != nullptr) { warpSum(v_xy_abs_local, warp); } @@ -252,33 +332,52 @@ __global__ void rasterize_to_pixels_bwd_kernel( gpuAtomicAdd(v_xy_abs_ptr + 1, v_xy_abs_local.y); } + if constexpr (GEO) { + S *v_ray_t_ptr = (S*)(v_ray_ts) + g; + gpuAtomicAdd(v_ray_t_ptr, v_ray_t_local); + S *v_ray_plane_ptr = (S*)(v_ray_planes) + 2 * g; + gpuAtomicAdd(v_ray_plane_ptr, v_ray_plane_local.x); + gpuAtomicAdd(v_ray_plane_ptr + 1, v_ray_plane_local.y); + S *v_normal_ptr = (S*)(v_normals) + 3 * g; + gpuAtomicAdd(v_normal_ptr, v_normal_local.x); + gpuAtomicAdd(v_normal_ptr 
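The geometry branch of the backward loop mirrors the color path: each Gaussian's blended ray distance at pixel offset delta is t_opt = ray_t + <delta, ray_plane>, so its gradient splits into a scalar part for ray_t and a delta-scaled part for ray_plane, and the alpha gradient reuses the same running buffer of contributions already accumulated behind the current Gaussian. A hedged Python sketch of just those terms, with the median-depth routing omitted and all names illustrative:

import numpy as np

def depth_terms(ray_t, ray_plane, delta, fac, T, ra, t_buffer, v_render_et):
    # fac plays the role of alpha * T and ra that of 1 / (1 - alpha) in the kernel
    t_opt = ray_t + delta @ ray_plane
    v_ray_t_local = fac * v_render_et            # d t_out / d ray_t, scaled by upstream grad
    v_ray_plane_local = v_ray_t_local * delta    # d t_opt / d ray_plane = delta
    v_alpha_depth = (t_opt * T - t_buffer * ra) * v_render_et
    t_buffer = t_buffer + t_opt * fac            # what nearer Gaussians later see as "behind"
    return v_ray_t_local, v_ray_plane_local, v_alpha_depth, t_buffer

alpha, T = 0.4, 0.6
out = depth_terms(ray_t=2.0, ray_plane=np.array([0.01, -0.02]),
                  delta=np.array([0.3, -0.5]), fac=alpha * T, T=T,
                  ra=1.0 / (1.0 - alpha), t_buffer=0.0, v_render_et=1.0)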
+ 1, v_normal_local.y); + gpuAtomicAdd(v_normal_ptr + 2, v_normal_local.z); + } + gpuAtomicAdd(v_opacities + g, v_opacity_local); } } } } -template -std::tuple -call_kernel_with_dim( +template +T call_kernel_with_dim( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &ray_ts, // [C, N] or [nnz] + const torch::Tensor &ray_planes, // [C, N, 2] or [nnz, 2] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3] const at::optional &backgrounds, // [C, 3] const at::optional &masks, // [C, tile_height, tile_width] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const torch::Tensor &Ks, // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids, // [n_isects] // forward outputs const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] const torch::Tensor &last_ids, // [C, image_height, image_width] + const torch::Tensor &median_ids, // [C, image_height, image_width] // gradients of outputs const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_expected_depths, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_median_depths, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_expected_normals, // [C, image_height, image_width, 3] // options bool absgrad) { @@ -302,6 +401,17 @@ call_kernel_with_dim( bool packed = means2d.dim() == 2; + constexpr unsigned int output_size = std::tuple_size_v; + constexpr bool GEO = output_size > 5; + if constexpr (GEO) + { + CHECK_INPUT(ray_ts); + CHECK_INPUT(ray_planes); + CHECK_INPUT(normals); + CHECK_INPUT(v_render_expected_depths); + CHECK_INPUT(v_render_expected_normals); + } + uint32_t C = tile_offsets.size(0); // number of cameras uint32_t N = packed ? 0 : means2d.size(1); // number of gaussians uint32_t n_isects = flatten_ids.size(0); @@ -322,63 +432,92 @@ call_kernel_with_dim( if (absgrad) { v_means2d_abs = torch::zeros_like(means2d); } + torch::Tensor v_ray_ts = GEO ? torch::zeros_like(ray_ts) : torch::Tensor(); + torch::Tensor v_ray_planes = GEO ? torch::zeros_like(ray_planes) : torch::Tensor(); + torch::Tensor v_normals = GEO ? torch::zeros_like(normals) : torch::Tensor(); + if (n_isects) { + // const uint32_t shared_mem = tile_size * tile_size * + // (sizeof(int32_t) + sizeof(vec3) + + // sizeof(vec3) + sizeof(float) * COLOR_DIM); const uint32_t shared_mem = tile_size * tile_size * - (sizeof(int32_t) + sizeof(vec3) + - sizeof(vec3) + sizeof(float) * COLOR_DIM); + (GEO ? 
+ sizeof(int32_t) + sizeof(vec3) + sizeof(vec3) + sizeof(float) + sizeof(vec2) + sizeof(vec3) + sizeof(float) * COLOR_DIM + : sizeof(int32_t) + sizeof(vec3) + sizeof(vec3) + sizeof(float) * COLOR_DIM); at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); - if (cudaFuncSetAttribute(rasterize_to_pixels_bwd_kernel, + if (cudaFuncSetAttribute(rasterize_to_pixels_bwd_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem) != cudaSuccess) { AT_ERROR("Failed to set maximum shared memory size (requested ", shared_mem, " bytes), try lowering tile_size."); } - rasterize_to_pixels_bwd_kernel + rasterize_to_pixels_bwd_kernel <<>>( C, N, n_isects, packed, reinterpret_cast *>(means2d.data_ptr()), reinterpret_cast *>(conics.data_ptr()), colors.data_ptr(), opacities.data_ptr(), - backgrounds.has_value() ? backgrounds.value().data_ptr() - : nullptr, + GEO ? ray_ts.data_ptr() : nullptr, + GEO ? reinterpret_cast *>(ray_planes.data_ptr()) : nullptr, + GEO ? reinterpret_cast *>(normals.data_ptr()) : nullptr, + backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, masks.has_value() ? masks.value().data_ptr(): nullptr, image_width, image_height, tile_size, tile_width, tile_height, + GEO ? Ks.data_ptr() : nullptr, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), + render_alphas.data_ptr(), last_ids.data_ptr(), + GEO ? median_ids.data_ptr() : nullptr, v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + GEO ? v_render_expected_depths.data_ptr() : nullptr, + GEO ? v_render_median_depths.data_ptr() : nullptr, + GEO ? reinterpret_cast *>(v_render_expected_normals.data_ptr()) : nullptr, absgrad ? reinterpret_cast *>(v_means2d_abs.data_ptr()) : nullptr, reinterpret_cast *>(v_means2d.data_ptr()), reinterpret_cast *>(v_conics.data_ptr()), - v_colors.data_ptr(), v_opacities.data_ptr()); + v_colors.data_ptr(), v_opacities.data_ptr(), + GEO ? v_ray_ts.data_ptr() : nullptr, + GEO ? v_ray_planes.data_ptr() : nullptr, + GEO ? 
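Because the geometry path stages ray_t, ray_plane and normal per Gaussian in shared memory next to the id, position/opacity, conic and color values, the dynamic shared-memory request grows by 24 bytes per staged Gaussian. A back-of-the-envelope Python sketch, assuming 4-byte floats and unpadded glm vectors (vec2 = 8 B, vec3 = 12 B):

# Rough arithmetic only; not a binding layout.
def bwd_shared_mem_bytes(tile_size, color_dim, geo):
    per_gaussian = 4 + 12 + 12 + 4 * color_dim   # int32 id + vec3 xy/opacity + vec3 conic + colors
    if geo:
        per_gaussian += 4 + 8 + 12               # ray_t + ray_plane + normal
    return tile_size * tile_size * per_gaussian

# e.g. a 16x16 tile with RGB: 10240 B without geometry, 16384 B with it
print(bwd_shared_mem_bytes(16, 3, False), bwd_shared_mem_bytes(16, 3, True))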
v_normals.data_ptr() : nullptr); } - - return std::make_tuple(v_means2d_abs, v_means2d, v_conics, v_colors, v_opacities); + if constexpr (GEO) + return std::make_tuple(v_means2d_abs, v_means2d, v_conics, v_colors, v_opacities, v_ray_ts, v_ray_planes, v_normals); + else + return std::make_tuple(v_means2d_abs, v_means2d, v_conics, v_colors, v_opacities); } -std::tuple -rasterize_to_pixels_bwd_tensor( +// std::tuple +template +T rasterize_to_pixels_bwd_tensor( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &ray_ts, // [C, N] or [nnz] + const torch::Tensor &ray_planes, // [C, N, 2] or [nnz, 2] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3] const at::optional &backgrounds, // [C, 3] const at::optional &masks, // [C, tile_height, tile_width] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const torch::Tensor &Ks, // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids, // [n_isects] // forward outputs const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] const torch::Tensor &last_ids, // [C, image_height, image_width] + const torch::Tensor &median_ids, // [C, image_height, image_width] // gradients of outputs const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_expected_depths, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_median_depths, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_expected_normals, // [C, image_height, image_width, 3] // options bool absgrad) { @@ -387,10 +526,12 @@ rasterize_to_pixels_bwd_tensor( #define __GS__CALL_(N) \ case N: \ - return call_kernel_with_dim( \ - means2d, conics, colors, opacities, backgrounds, masks, image_width, \ - image_height, tile_size, tile_offsets, flatten_ids, render_alphas, \ - last_ids, v_render_colors, v_render_alphas, absgrad); + return call_kernel_with_dim( \ + means2d, conics, colors, opacities, ray_ts, ray_planes, normals, \ + backgrounds, masks, image_width, image_height, tile_size, Ks, \ + tile_offsets, flatten_ids, render_alphas, \ + last_ids, median_ids, v_render_colors, v_render_alphas, \ + v_render_expected_depths, v_render_median_depths, v_render_expected_normals, absgrad); switch (COLOR_DIM) { __GS__CALL_(1) @@ -416,3 +557,76 @@ rasterize_to_pixels_bwd_tensor( AT_ERROR("Unsupported number of channels: ", COLOR_DIM); } } + +std::tuple +rasterize_to_pixels_w_depth_bwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &ray_ts, // [C, N] or [nnz] + const torch::Tensor &ray_planes, // [C, N, 2] or [nnz, 2] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3] + const at::optional &backgrounds, // [C, 3] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const torch::Tensor &Ks, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const 
torch::Tensor &flatten_ids, // [n_isects] + // forward outputs + const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &last_ids, // [C, image_height, image_width] + const torch::Tensor &median_ids, // [C, image_height, image_width] + // gradients of outputs + const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_expected_depths, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_median_depths, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_expected_normals, // [C, image_height, image_width, 3] + // options + bool absgrad) { + return rasterize_to_pixels_bwd_tensor> + ( + means2d, conics, colors, opacities, ray_ts, ray_planes, normals, + backgrounds, masks, image_width, image_height, tile_size, Ks, + tile_offsets, flatten_ids, render_alphas, + last_ids, median_ids, v_render_colors, v_render_alphas, + v_render_expected_depths, v_render_median_depths, v_render_expected_normals, absgrad + ); +} + +std::tuple +rasterize_to_pixels_wo_depth_bwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] + const torch::Tensor &opacities, // [C, N] or [nnz] + const at::optional &backgrounds, // [C, 3] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // forward outputs + const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &last_ids, // [C, image_height, image_width] + // gradients of outputs + const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + // options + bool absgrad) { + torch::Tensor Ks, ray_ts, ray_planes, normals, median_ids, v_render_expected_depths, v_render_median_depths, v_render_expected_normals; + return rasterize_to_pixels_bwd_tensor> + ( + means2d, conics, colors, opacities, ray_ts, ray_planes, normals, + backgrounds, masks, image_width, image_height, tile_size, Ks, + tile_offsets, flatten_ids, render_alphas, + last_ids, median_ids, v_render_colors, v_render_alphas, + v_render_expected_depths, v_render_median_depths, v_render_expected_normals, absgrad + ); +} \ No newline at end of file diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu index 6a89b7148..14cab079b 100644 --- a/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu +++ b/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu @@ -11,21 +11,30 @@ namespace cg = cooperative_groups; * Rasterization to Pixels Forward Pass ****************************************************************************/ -template +template __global__ void rasterize_to_pixels_fwd_kernel( const uint32_t C, const uint32_t N, const uint32_t n_isects, const bool packed, + const vec2 *__restrict__ means2d, // [C, N, 2] or [nnz, 2] const vec3 *__restrict__ conics, // [C, N, 3] or [nnz, 3] const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] const S *__restrict__ opacities, // [C, N] or [nnz] + const S *__restrict__ ray_ts, // [C, N] or [nnz] + const vec2 *__restrict__ 
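The two entry points above are thin wrappers over the same templated routine: the geometry-free variant passes default-constructed tensors for the unused arguments, and the template picks the GEO path at compile time from the arity of the requested return tuple. A Python analogue of that design choice, with runtime branching standing in for the compile-time check (all names illustrative):

import numpy as np

def bwd_core(n, color_dim, geo):
    # shared core; the "with depth" variant simply returns three extra gradients
    v_means2d, v_conics = np.zeros((n, 2)), np.zeros((n, 3))
    v_colors, v_opacities = np.zeros((n, color_dim)), np.zeros(n)
    if not geo:
        return v_means2d, v_conics, v_colors, v_opacities
    v_ray_ts, v_ray_planes, v_normals = np.zeros(n), np.zeros((n, 2)), np.zeros((n, 3))
    return v_means2d, v_conics, v_colors, v_opacities, v_ray_ts, v_ray_planes, v_normals

assert len(bwd_core(10, 3, geo=False)) == 4
assert len(bwd_core(10, 3, geo=True)) == 7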
ray_planes, // [C, N, 2] or [nnz, 2] + const vec3 *__restrict__ normals, // [C, N, 3] or [nnz, 3] const S *__restrict__ backgrounds, // [C, COLOR_DIM] const bool *__restrict__ masks, // [C, tile_height, tile_width] + const S *__restrict__ Ks, // [C, 9] const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, const uint32_t tile_width, const uint32_t tile_height, const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] const int32_t *__restrict__ flatten_ids, // [n_isects] S *__restrict__ render_colors, // [C, image_height, image_width, COLOR_DIM] S *__restrict__ render_alphas, // [C, image_height, image_width, 1] + S *__restrict__ render_depths, // [C, image_height, image_width, 1] + S *__restrict__ median_depths, // [C, image_height, image_width, 1] + S *__restrict__ render_normals,// [C, image_height, image_width, 3] + int32_t *__restrict__ median_ids, // [C, image_height, image_width] int32_t *__restrict__ last_ids // [C, image_height, image_width] ) { // each thread draws one pixel, but also timeshares caching gaussians in a @@ -51,6 +60,15 @@ __global__ void rasterize_to_pixels_fwd_kernel( S px = (S)j + 0.5f; S py = (S)i + 0.5f; int32_t pix_id = i * image_width + j; + S ln; + if constexpr (GEO) + { + Ks += camera_id * 9; + S fx = Ks[0], cx = Ks[2], fy = Ks[4], cy = Ks[5]; + vec3 pixnf = {(px - cx) / fx, (py - cy) / fy, 1}; + ln = glm::length(pixnf); + } + // return if out of bounds // keep not rasterizing threads around for reading data @@ -83,7 +101,18 @@ __global__ void rasterize_to_pixels_fwd_kernel( reinterpret_cast *>(&id_batch[block_size]); // [block_size] vec3 *conic_batch = reinterpret_cast *>(&xy_opacity_batch[block_size]); // [block_size] - + S *ray_t_batch; + vec2 *ray_plane_batch; + vec3 *normal_batch; + if constexpr (GEO) + { + ray_t_batch = + reinterpret_cast(&conic_batch[block_size]); // [block_size] + ray_plane_batch = + reinterpret_cast *>(&ray_t_batch[block_size]); // [block_size] + normal_batch = + reinterpret_cast *>(&ray_plane_batch[block_size]); // [block_size] + } // current visibility left to render // transmittance is gonna be used in the backward pass which requires a high // numerical precision so we use double for it. 
However double make bwd 1.5x slower @@ -91,6 +120,7 @@ __global__ void rasterize_to_pixels_fwd_kernel( S T = 1.0f; // index of most recent gaussian to write to this thread's pixel uint32_t cur_idx = 0; + uint32_t median_idx = -1; // collect and process batches of gaussians // each thread loads one gaussian at a time before rasterizing its @@ -98,6 +128,9 @@ __global__ void rasterize_to_pixels_fwd_kernel( uint32_t tr = block.thread_rank(); S pix_out[COLOR_DIM] = {0.f}; + S t_out = 0.f; + S normal_out[3] = {0.f}; + S t_median = 0.f; for (uint32_t b = 0; b < num_batches; ++b) { // resync all threads before beginning next batch // end early if entire tile is done @@ -116,6 +149,12 @@ __global__ void rasterize_to_pixels_fwd_kernel( const S opac = opacities[g]; xy_opacity_batch[tr] = {xy.x, xy.y, opac}; conic_batch[tr] = conics[g]; + if constexpr (GEO) + { + ray_t_batch[tr] = ray_ts[g]; + ray_plane_batch[tr] = ray_planes[g]; + normal_batch[tr] = normals[g]; + } } // wait for other threads to collect the gaussians in batch @@ -149,6 +188,26 @@ __global__ void rasterize_to_pixels_fwd_kernel( for (uint32_t k = 0; k < COLOR_DIM; ++k) { pix_out[k] += c_ptr[k] * vis; } + if constexpr (GEO) + { + const S ray_t = ray_t_batch[t]; + const vec2 ray_plane = ray_plane_batch[t]; + const vec3 normal = normal_batch[t]; + PRAGMA_UNROLL + for (uint32_t k = 0; k < 3; ++k) { + normal_out[k] += normal[k] * vis; + } + + S t_opt = ray_t + glm::dot(delta, ray_plane); + t_out += t_opt * vis; + if (T > 0.5) + { + median_idx = batch_start + t; + t_median = t_opt; + } + // printf("%f %f %f %f %f %f\n",depth_out,t_opt,ray_t,ln,ray_plane.x,ray_plane.y); + } + cur_idx = batch_start + t; T = next_T; @@ -167,22 +226,36 @@ __global__ void rasterize_to_pixels_fwd_kernel( render_colors[pix_id * COLOR_DIM + k] = backgrounds == nullptr ? 
pix_out[k] : (pix_out[k] + T * backgrounds[k]); } + if constexpr (GEO) + { + render_depths[pix_id] = t_out / ln; + median_depths[pix_id] = t_median / ln; + median_ids[pix_id] = median_idx; + PRAGMA_UNROLL + for (uint32_t k = 0; k < 3; ++k){ + render_normals[pix_id * 3 + k] = normal_out[k]; + } + } // index in bin of last gaussian in this pixel last_ids[pix_id] = static_cast(cur_idx); } } -template -std::tuple call_kernel_with_dim( +template +T call_kernel_with_dim( // Gaussian parameters - const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] - const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] - const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] - const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &ray_ts, // [C, N] or [nnz] + const torch::Tensor &ray_planes, // [C, N, 2] or [nnz, 2] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3] const at::optional &backgrounds, // [C, channels] const at::optional &masks, // [C, tile_height, tile_width] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const torch::Tensor &Ks, // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids // [n_isects] @@ -202,6 +275,15 @@ std::tuple call_kernel_with_dim( } bool packed = means2d.dim() == 2; + constexpr unsigned int output_size = std::tuple_size_v; + constexpr bool GEO = output_size > 3; + if constexpr (GEO) + { + CHECK_INPUT(ray_ts); + CHECK_INPUT(ray_planes); + CHECK_INPUT(normals); + } + uint32_t C = tile_offsets.size(0); // number of cameras uint32_t N = packed ? 0 : means2d.size(1); // number of gaussians uint32_t channels = colors.size(-1); @@ -218,61 +300,88 @@ std::tuple call_kernel_with_dim( means2d.options().dtype(torch::kFloat32)); torch::Tensor alphas = torch::empty({C, image_height, image_width, 1}, means2d.options().dtype(torch::kFloat32)); + torch::Tensor expected_depths = GEO ? torch::empty({C, image_height, image_width, 1}, + means2d.options().dtype(torch::kFloat32)) : torch::Tensor(); + torch::Tensor median_depths = GEO ? torch::empty({C, image_height, image_width, 1}, + means2d.options().dtype(torch::kFloat32)) : torch::Tensor(); + torch::Tensor expected_normals = GEO ? torch::empty({C, image_height, image_width, 3}, + means2d.options().dtype(torch::kFloat32)) : torch::Tensor(); + torch::Tensor median_ids = torch::empty({C, image_height, image_width}, + means2d.options().dtype(torch::kInt32)); torch::Tensor last_ids = torch::empty({C, image_height, image_width}, means2d.options().dtype(torch::kInt32)); at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); - const uint32_t shared_mem = + const uint32_t shared_mem = tile_size * tile_size * - (sizeof(int32_t) + sizeof(vec3) + sizeof(vec3)); + (GEO ? sizeof(int32_t) + sizeof(vec3) + sizeof(vec3) + sizeof(float) + sizeof(vec2) + sizeof(vec3) + :sizeof(int32_t) + sizeof(vec3) + sizeof(vec3)); // TODO: an optimization can be done by passing the actual number of channels into // the kernel functions and avoid necessary global memory writes. This requires // moving the channel padding from python to C side. 
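Per pixel, the forward kernel composites the per-Gaussian ray distances and normals front to back exactly like the colors, records the ray distance of the last Gaussian reached while the transmittance is still above 0.5 as the median depth, and finally divides the ray distances by ln to obtain planar depths. A minimal single-pixel compositor in NumPy (Gaussians assumed pre-sorted front to back, names illustrative):

import numpy as np

def composite_pixel(alphas, t_opts, normals, ln):
    T = 1.0                       # transmittance in front of the current Gaussian
    t_out, n_out = 0.0, np.zeros(3)
    t_median, median_idx = 0.0, -1
    for i, (a, t, n) in enumerate(zip(alphas, t_opts, normals)):
        vis = a * T
        t_out += vis * t
        n_out += vis * n
        if T > 0.5:               # still "in front of" the median surface
            median_idx, t_median = i, t
        T *= 1.0 - a
    return t_out / ln, t_median / ln, n_out, 1.0 - T, median_idx

d_exp, d_med, n_exp, alpha, mi = composite_pixel(
    alphas=np.array([0.4, 0.5, 0.3]),
    t_opts=np.array([2.0, 2.5, 3.0]),
    normals=np.tile([0.0, 0.0, -1.0], (3, 1)),
    ln=1.02,
)
# d_exp is not yet normalized by alpha; the Python wrapper does that later.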
- if (cudaFuncSetAttribute(rasterize_to_pixels_fwd_kernel, + if (cudaFuncSetAttribute(rasterize_to_pixels_fwd_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem) != cudaSuccess) { AT_ERROR("Failed to set maximum shared memory size (requested ", shared_mem, " bytes), try lowering tile_size."); } - rasterize_to_pixels_fwd_kernel + + rasterize_to_pixels_fwd_kernel <<>>( C, N, n_isects, packed, reinterpret_cast *>(means2d.data_ptr()), reinterpret_cast *>(conics.data_ptr()), colors.data_ptr(), opacities.data_ptr(), + GEO ? ray_ts.data_ptr() : nullptr, + GEO ? reinterpret_cast *>(ray_planes.data_ptr()) : nullptr, + GEO ? reinterpret_cast *>(normals.data_ptr()) : nullptr, backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, masks.has_value() ? masks.value().data_ptr() : nullptr, + GEO ? Ks.data_ptr() : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + GEO ? expected_depths.data_ptr() : nullptr, + GEO ? median_depths.data_ptr() : nullptr, + GEO ? expected_normals.data_ptr() : nullptr, + GEO ? median_ids.data_ptr() : nullptr, last_ids.data_ptr()); - - return std::make_tuple(renders, alphas, last_ids); + + if constexpr (GEO) + return std::make_tuple(renders, alphas, expected_depths, median_depths, expected_normals, median_ids, last_ids); + else + return std::make_tuple(renders, alphas, last_ids); } -std::tuple rasterize_to_pixels_fwd_tensor( +template +T rasterize_to_pixels_fwd( // Gaussian parameters - const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] - const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] - const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] - const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &ray_ts, // [C, N] or [nnz] + const torch::Tensor &ray_planes, // [C, N, 2] or [nnz, 2] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3] const at::optional &backgrounds, // [C, channels] const at::optional &masks, // [C, tile_height, tile_width] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const torch::Tensor &Ks, // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids // [n_isects] -) { +){ CHECK_INPUT(colors); uint32_t channels = colors.size(-1); #define __GS__CALL_(N) \ case N: \ - return call_kernel_with_dim(means2d, conics, colors, opacities, \ - backgrounds, masks, image_width, image_height, \ - tile_size, tile_offsets, flatten_ids); + return call_kernel_with_dim(means2d, conics, colors, opacities, \ + ray_ts, ray_planes, normals, \ + backgrounds, masks, image_width, image_height, \ + tile_size, Ks, tile_offsets, flatten_ids); // TODO: an optimization can be done by passing the actual number of channels into // the kernel functions and avoid necessary global memory writes. 
This requires @@ -301,3 +410,54 @@ std::tuple rasterize_to_pixels_fwd_ AT_ERROR("Unsupported number of channels: ", channels); } } + + +std::tuple +rasterize_to_pixels_wo_depth_fwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const at::optional &backgrounds, // [C, channels] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +){ + torch::Tensor ray_ts; + torch::Tensor ray_planes; + torch::Tensor normals; + torch::Tensor Ks; + return rasterize_to_pixels_fwd> + (means2d, conics, colors, opacities, ray_ts, ray_planes, normals, + backgrounds, masks, image_width, image_height, tile_size, + Ks, tile_offsets, flatten_ids); +} + +std::tuple +rasterize_to_pixels_w_depth_fwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &ray_ts, // [C, N] or [nnz] + const torch::Tensor &ray_planes, // [C, N, 2] or [nnz, 2] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3] + const at::optional &backgrounds, // [C, channels] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const torch::Tensor &Ks, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +){ + return rasterize_to_pixels_fwd> + (means2d, conics, colors, opacities, ray_ts, ray_planes, normals, + backgrounds, masks, image_width, image_height, tile_size, + Ks, tile_offsets, flatten_ids); +} \ No newline at end of file diff --git a/gsplat/cuda/csrc/symeigen.cuh b/gsplat/cuda/csrc/symeigen.cuh new file mode 100644 index 000000000..b900862ee --- /dev/null +++ b/gsplat/cuda/csrc/symeigen.cuh @@ -0,0 +1,227 @@ +#ifndef SYMEIGEN +#define SYMEIGEN +#include + +namespace glm_modification +{ + // Incorporate the transferSign, pythag, equal, and findEigenvaluesSymReal functions from the glm library, + // with small modifications on findEgienvaluesSymReal to ensure numerical stability for big Gaussian kernels. + // https://github.com/g-truc/glm/blob/33b4a621a697a305bc3a7610d290677b96beb181/glm/gtx/pca.inl + // https://github.com/g-truc/glm/blob/33b4a621a697a305bc3a7610d290677b96beb181/glm/ext/scalar_relational.inl + template + __forceinline__ __device__ bool equal(genType const& x, genType const& y, genType const& epsilon) + { + return abs(x - y) <= epsilon; + } + + template + __forceinline__ __device__ static T transferSign(T const& v, T const& s) + { + return ((s) >= 0 ? 
glm::abs(v) : -glm::abs(v)); + } + + template + __forceinline__ __device__ static T pythag(T const& a, T const& b) { + static const T epsilon = static_cast(0.0000001); + T absa = glm::abs(a); + T absb = glm::abs(b); + if(absa > absb) { + absb /= absa; + absb *= absb; + return absa * glm::sqrt(static_cast(1) + absb); + } + if(glm_modification::equal(absb, 0, epsilon)) return static_cast(0); + absa /= absb; + absa *= absa; + return absb * glm::sqrt(static_cast(1) + absa); + } + + + template + __forceinline__ __device__ unsigned int findEigenvaluesSymReal + ( + glm::mat const& covarMat, + glm::vec& outEigenvalues, + glm::mat& outEigenvectors + ) + { + + T a[D * D]; // matrix -- input and workspace for algorithm (will be changed inplace) + T d[D]; // diagonal elements + T e[D]; // off-diagonal elements + + for(glm::length_t r = 0; r < D; r++) + for(glm::length_t c = 0; c < D; c++) + a[(r) * D + (c)] = covarMat[c][r]; + + // 1. Householder reduction. + glm::length_t l, k, j, i; + T scale, hh, h, g, f; + static const T epsilon = static_cast(0.0000001); + + for(i = D; i >= 2; i--) + { + l = i - 1; + h = scale = 0; + if(l > 1) + { + for(k = 1; k <= l; k++) + { + scale += glm::abs(a[(i - 1) * D + (k - 1)]); + } + if(glm_modification::equal(scale, 0, epsilon)) + { + e[i - 1] = a[(i - 1) * D + (l - 1)]; + } + else + { + for(k = 1; k <= l; k++) + { + a[(i - 1) * D + (k - 1)] /= scale; + h += a[(i - 1) * D + (k - 1)] * a[(i - 1) * D + (k - 1)]; + } + f = a[(i - 1) * D + (l - 1)]; + g = ((f >= 0) ? -glm::sqrt(h) : glm::sqrt(h)); + e[i - 1] = scale * g; + h -= f * g; + a[(i - 1) * D + (l - 1)] = f - g; + f = 0; + for(j = 1; j <= l; j++) + { + a[(j - 1) * D + (i - 1)] = a[(i - 1) * D + (j - 1)] / h; + g = 0; + for(k = 1; k <= j; k++) + { + g += a[(j - 1) * D + (k - 1)] * a[(i - 1) * D + (k - 1)]; + } + for(k = j + 1; k <= l; k++) + { + g += a[(k - 1) * D + (j - 1)] * a[(i - 1) * D + (k - 1)]; + } + e[j - 1] = g / h; + f += e[j - 1] * a[(i - 1) * D + (j - 1)]; + } + hh = f / (h + h); + for(j = 1; j <= l; j++) + { + f = a[(i - 1) * D + (j - 1)]; + e[j - 1] = g = e[j - 1] - hh * f; + for(k = 1; k <= j; k++) + { + a[(j - 1) * D + (k - 1)] -= (f * e[k - 1] + g * a[(i - 1) * D + (k - 1)]); + } + } + } + } + else + { + e[i - 1] = a[(i - 1) * D + (l - 1)]; + } + d[i - 1] = h; + } + d[0] = 0; + e[0] = 0; + for(i = 1; i <= D; i++) + { + l = i - 1; + if(!glm_modification::equal(d[i - 1], 0, epsilon)) + { + for(j = 1; j <= l; j++) + { + g = 0; + for(k = 1; k <= l; k++) + { + g += a[(i - 1) * D + (k - 1)] * a[(k - 1) * D + (j - 1)]; + } + for(k = 1; k <= l; k++) + { + a[(k - 1) * D + (j - 1)] -= g * a[(k - 1) * D + (i - 1)]; + } + } + } + d[i - 1] = a[(i - 1) * D + (i - 1)]; + a[(i - 1) * D + (i - 1)] = 1; + for(j = 1; j <= l; j++) + { + a[(j - 1) * D + (i - 1)] = a[(i - 1) * D + (j - 1)] = 0; + } + } + + // 2. 
Calculation of eigenvalues and eigenvectors (QL algorithm) + glm::length_t m, iter; + T s, r, p, dd, c, b; + const glm::length_t MAX_ITER = 30; + + for(i = 2; i <= D; i++) + { + e[i - 2] = e[i - 1]; + } + e[D - 1] = 0; + + for(l = 1; l <= D; l++) + { + iter = 0; + do + { + for(m = l; m <= D - 1; m++) + { + dd = glm::abs(d[m - 1]) + glm::abs(d[m - 1 + 1]); + if(glm_modification::equal(glm::abs(e[m - 1]), 0, epsilon)) + break; + } + if(m != l) + { + if(iter++ == MAX_ITER) + { + return 0; // Too many iterations in FindEigenvalues + } + g = (d[l - 1 + 1] - d[l - 1]) / (2 * e[l - 1]); + r = pythag(g, 1); + g = d[m - 1] - d[l - 1] + e[l - 1] / (g + transferSign(r, g)); + s = c = 1; + p = 0; + for(i = m - 1; i >= l; i--) + { + f = s * e[i - 1]; + b = c * e[i - 1]; + e[i - 1 + 1] = r = pythag(f, g); + if(glm_modification::equal(r, 0, epsilon)) + { + d[i - 1 + 1] -= p; + e[m - 1] = 0; + break; + } + s = f / r; + c = g / r; + g = d[i - 1 + 1] - p; + r = (d[i - 1] - g) * s + 2 * c * b; + d[i - 1 + 1] = g + (p = s * r); + g = c * r - b; + for(k = 1; k <= D; k++) + { + f = a[(k - 1) * D + (i - 1 + 1)]; + a[(k - 1) * D + (i - 1 + 1)] = s * a[(k - 1) * D + (i - 1)] + c * f; + a[(k - 1) * D + (i - 1)] = c * a[(k - 1) * D + (i - 1)] - s * f; + } + } + if(glm_modification::equal(r, 0, epsilon) && (i >= l)) + continue; + d[l - 1] -= p; + e[l - 1] = g; + e[m - 1] = 0; + } + } while(m != l); + } + + // 3. output + for(i = 0; i < D; i++) + outEigenvalues[i] = d[i]; + for(i = 0; i < D; i++) + for(j = 0; j < D; j++) + outEigenvectors[i][j] = a[(j) * D + (i)]; + + return D; + } +} + +#endif \ No newline at end of file diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 7e23a0f8e..f69a90173 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -2,6 +2,7 @@ #define GSPLAT_CUDA_UTILS_H #include "helpers.cuh" +#include "symeigen.cuh" #include #include @@ -148,7 +149,7 @@ inline __device__ void persp_proj( const vec3 mean3d, const mat3 cov3d, const T fx, const T fy, const T cx, const T cy, const uint32_t width, const uint32_t height, // outputs - mat2 &cov2d, vec2 &mean2d) { + mat2 &cov2d, vec2 &mean2d, vec2 &ray_plane, vec3 &normal) { T x = mean3d[0], y = mean3d[1], z = mean3d[2]; T tan_fovx = 0.5f * width / fx; @@ -158,8 +159,10 @@ inline __device__ void persp_proj( T rz = 1.f / z; T rz2 = rz * rz; - T tx = z * min(lim_x, max(-lim_x, x * rz)); - T ty = z * min(lim_y, max(-lim_y, y * rz)); + T u = min(lim_x, max(-lim_x, x * rz)); + T v = min(lim_y, max(-lim_y, y * rz)); + T tx = z * u; + T ty = z * v; // mat3x2 is 3 columns x 2 rows. mat3x2 J = mat3x2(fx * rz, 0.f, // 1st column @@ -168,6 +171,60 @@ inline __device__ void persp_proj( ); cov2d = J * cov3d * glm::transpose(J); mean2d = vec2({fx * x * rz + cx, fy * y * rz + cy}); + + // calculate the ray space intersection plane. + auto length = [](T x, T y, T z) { return sqrt(x*x+y*y+z*z); }; + mat3 cov3d_eigen_vector; + vec3 cov3d_eigen_value; + int D = glm_modification::findEigenvaluesSymReal(cov3d,cov3d_eigen_value,cov3d_eigen_vector); + unsigned int min_id = cov3d_eigen_value[0]>cov3d_eigen_value[1]? 
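symeigen.cuh ports glm's symmetric eigensolver (Householder reduction followed by QL iterations) so the projection kernels can decompose the 3x3 covariance on the device; the decomposition is then used either to invert the covariance or, when the smallest eigenvalue is tiny, to fall back to a rank-one surrogate built from the flattest axis. A NumPy sketch of that usage, with numpy.linalg.eigh standing in for findEigenvaluesSymReal:

import numpy as np

def cov3d_pseudo_inverse(cov3d, eps=1e-8):
    w, V = np.linalg.eigh(cov3d)            # ascending eigenvalues, orthonormal columns
    if w[0] > eps:                          # well conditioned: ordinary inverse
        return V @ np.diag(1.0 / w) @ V.T
    v_min = V[:, 0]                         # eigenvector of the smallest eigenvalue
    return np.outer(v_min, v_min)           # degenerate fallback used by the kernel

Cinv = cov3d_pseudo_inverse(np.diag([0.04, 0.01, 1e-12]))   # a nearly flat Gaussian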
(cov3d_eigen_value[1]>cov3d_eigen_value[2]?2:1):(cov3d_eigen_value[0]>cov3d_eigen_value[2]?2:0); + mat3 cov3d_inv; + bool well_conditioned = cov3d_eigen_value[min_id]>1E-8; + vec3 eigenvector_min; + if(well_conditioned) + { + mat3 diag = mat3( 1/cov3d_eigen_value[0], 0, 0, + 0, 1/cov3d_eigen_value[1], 0, + 0, 0, 1/cov3d_eigen_value[2] ); + cov3d_inv = cov3d_eigen_vector * diag * glm::transpose(cov3d_eigen_vector); + } + else + { + eigenvector_min = cov3d_eigen_vector[min_id]; + cov3d_inv = glm::outerProduct(eigenvector_min,eigenvector_min); + } + vec3 uvh = {u, v, 1}; + vec3 Cinv_uvh = cov3d_inv * uvh; + if(length(Cinv_uvh.x, Cinv_uvh.y, Cinv_uvh.z) < 1E-12 || D ==0) + { + normal = {0, 0, 0}; + ray_plane = {0, 0}; + } + else + { + T l = length(tx, ty, z); + mat3 nJ_T = glm::mat3(rz, 0.f, -tx * rz2, // 1st column + 0.f, rz, -ty * rz2, // 2nd column + tx/l, ty/l, z/l // 3rd column + ); + T uu = u * u; + T vv = v * v; + T uv = u * v; + + mat3x2 nJ_inv_T = mat3x2(vv + 1, -uv, // 1st column + -uv, uu + 1, // 2nd column + -u, -v // 3nd column + ); + T factor = l / (uu + vv + 1); + vec3 Cinv_uvh_n = glm::normalize(Cinv_uvh); + T u_Cinv_u_n_clmap = max(glm::dot(Cinv_uvh_n, uvh), 1E-7); + vec2 plane = nJ_inv_T * (Cinv_uvh_n / u_Cinv_u_n_clmap); + vec3 ray_normal_vector = {-plane.x*factor, -plane.y*factor, -1}; + vec3 cam_normal_vector = nJ_T * ray_normal_vector; + normal = glm::normalize(cam_normal_vector); + ray_plane = {plane.x * factor / fx, plane.y * factor / fy}; + } + } template @@ -176,7 +233,7 @@ inline __device__ void persp_proj_vjp( const vec3 mean3d, const mat3 cov3d, const T fx, const T fy, const T cx, const T cy, const uint32_t width, const uint32_t height, // grad outputs - const mat2 v_cov2d, const vec2 v_mean2d, + const mat2 v_cov2d, const vec2 v_mean2d, const vec2 v_ray_plane, const vec3 v_normal, // grad inputs vec3 &v_mean3d, mat3 &v_cov3d) { T x = mean3d[0], y = mean3d[1], z = mean3d[2]; @@ -188,8 +245,11 @@ inline __device__ void persp_proj_vjp( T rz = 1.f / z; T rz2 = rz * rz; - T tx = z * min(lim_x, max(-lim_x, x * rz)); - T ty = z * min(lim_y, max(-lim_y, y * rz)); + T u = min(lim_x, max(-lim_x, x * rz)); + T v = min(lim_y, max(-lim_y, y * rz)); + T tx = z * u; + T ty = z * v; + mat3 v_cov3d_ = {0,0,0,0,0,0,0,0,0}; // mat3x2 is 3 columns x 2 rows. mat3x2 J = mat3x2(fx * rz, 0.f, // 1st column @@ -197,10 +257,124 @@ inline __device__ void persp_proj_vjp( -fx * tx * rz2, -fy * ty * rz2 // 3rd column ); + // calculate the ray space intersection plane. + auto length = [](T x, T y, T z) { return sqrt(x*x+y*y+z*z); }; + mat3 cov3d_eigen_vector; + vec3 cov3d_eigen_value; + int D = glm_modification::findEigenvaluesSymReal(cov3d,cov3d_eigen_value,cov3d_eigen_vector); + unsigned int min_id = cov3d_eigen_value[0]>cov3d_eigen_value[1]? 
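The new geometry in persp_proj is the core of the patch: from the camera-space mean and covariance it derives ray_plane, the linearized change of ray distance with pixel offset, and a unit surface normal. A hedged NumPy transcription of that forward computation (single Gaussian, FoV clamping and the zero-eigenvector early-out omitted), useful for sanity-checking the CUDA path:

import numpy as np

def ray_plane_and_normal(mean_c, cov3d, fx, fy, eps=1e-8):
    x, y, z = mean_c
    u, v = x / z, y / z
    tx, ty, rz = z * u, z * v, 1.0 / z
    w, V = np.linalg.eigh(cov3d)
    Cinv = V @ np.diag(1.0 / w) @ V.T if w[0] > eps else np.outer(V[:, 0], V[:, 0])
    uvh = np.array([u, v, 1.0])
    Cinv_uvh = Cinv @ uvh
    if np.linalg.norm(Cinv_uvh) < 1e-12:
        return np.zeros(2), np.zeros(3)
    l = np.sqrt(tx * tx + ty * ty + z * z)
    # columns as in the CUDA code (glm is column-major)
    nJ_T = np.column_stack([
        [rz, 0.0, -tx * rz * rz],
        [0.0, rz, -ty * rz * rz],
        [tx / l, ty / l, z / l],
    ])
    nJ_inv_T = np.array([[v * v + 1, -u * v, -u],
                         [-u * v, u * u + 1, -v]])
    factor = l / (u * u + v * v + 1)
    c = Cinv_uvh / np.linalg.norm(Cinv_uvh)
    plane = nJ_inv_T @ (c / max(np.dot(c, uvh), 1e-7))
    ray_normal = np.array([-plane[0] * factor, -plane[1] * factor, -1.0])
    normal = nJ_T @ ray_normal
    normal /= np.linalg.norm(normal)
    ray_plane = np.array([plane[0] * factor / fx, plane[1] * factor / fy])
    return ray_plane, normal

rp, n = ray_plane_and_normal(np.array([0.2, -0.1, 3.0]), np.diag([0.05, 0.02, 0.01]), 500.0, 500.0)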
(cov3d_eigen_value[1]>cov3d_eigen_value[2]?2:1):(cov3d_eigen_value[0]>cov3d_eigen_value[2]?2:0); + mat3 cov3d_inv; + bool well_conditioned = cov3d_eigen_value[min_id]>1E-8; + vec3 eigenvector_min; + if(well_conditioned) + { + mat3 diag = mat3( 1/cov3d_eigen_value[0], 0, 0, + 0, 1/cov3d_eigen_value[1], 0, + 0, 0, 1/cov3d_eigen_value[2] ); + cov3d_inv = cov3d_eigen_vector * diag * glm::transpose(cov3d_eigen_vector); + } + else + { + eigenvector_min = cov3d_eigen_vector[min_id]; + cov3d_inv = glm::outerProduct(eigenvector_min,eigenvector_min); + } + vec3 uvh = {u, v, 1}; + vec3 Cinv_uvh = cov3d_inv * uvh; + T l, v_u, v_v, v_l; + mat3 v_nJ_T; + if(length(Cinv_uvh.x, Cinv_uvh.y, Cinv_uvh.z) < 1E-12 || D ==0) + { + l = 1.f; + v_u = 0.f; + v_v = 0.f; + v_l = 0.f; + v_nJ_T = {0,0,0,0,0,0,0,0,0}; + } + else + { + l = length(tx, ty, z); + mat3 nJ_T = glm::mat3(rz, 0.f, -tx * rz2, // 1st column + 0.f, rz, -ty * rz2, // 2nd column + tx/l, ty/l, z/l // 3rd column + ); + T uu = u * u; + T vv = v * v; + T uv = u * v; + + mat3x2 nJ_inv_T = mat3x2(vv + 1, -uv, // 1st column + -uv, uu + 1, // 2nd column + -u, -v // 3nd column + ); + const T nl = uu + vv + 1; + T factor = l / nl; + vec3 Cinv_uvh_n = glm::normalize(Cinv_uvh); + T u_Cinv_u = glm::dot(Cinv_uvh, uvh); + T u_Cinv_u_n = glm::dot(Cinv_uvh_n, uvh); + T u_Cinv_u_clmap = max(u_Cinv_u, 1E-7); + T u_Cinv_u_n_clmap = max(u_Cinv_u_n, 1E-7); + mat3 cov3d_inv_u_Cinv_u = cov3d_inv / u_Cinv_u_clmap; + vec3 Cinv_uvh_u_Cinv_u = Cinv_uvh_n / u_Cinv_u_n_clmap; + vec2 plane = nJ_inv_T * Cinv_uvh_u_Cinv_u; + vec3 ray_normal_vector = {-plane.x*factor, -plane.y*factor, -1}; + vec3 cam_normal_vector = nJ_T * ray_normal_vector; + vec3 normal = glm::normalize(cam_normal_vector); + vec2 ray_plane = {plane.x * factor / fx, plane.y * factor / fy}; + + T cam_normal_vector_length = glm::length(cam_normal_vector); + + vec3 v_normal_l = v_normal / cam_normal_vector_length; + vec3 v_cam_normal_vector = v_normal_l - normal * glm::dot(normal,v_normal_l); + vec3 v_ray_normal_vector = glm::transpose(nJ_T) * v_cam_normal_vector; + v_nJ_T = glm::outerProduct(v_cam_normal_vector, ray_normal_vector); + + // ray_plane_uv = {plane.x * factor, plane.y * factor}; + const vec2 v_ray_plane_uv = {v_ray_plane.x / fx, v_ray_plane.y / fy}; + v_l = glm::dot(plane, -glm::make_vec2(v_ray_normal_vector) + v_ray_plane_uv) / nl; + vec2 v_plane = {factor * (-v_ray_normal_vector.x + v_ray_plane_uv.x), + factor * (-v_ray_normal_vector.y + v_ray_plane_uv.y)}; + T v_nl = (-v_ray_normal_vector.x * ray_normal_vector.x - v_ray_normal_vector.y * ray_normal_vector.y + -v_ray_plane.x*ray_plane.x - v_ray_plane.y*ray_plane.y) / nl; + + T tmp = glm::dot(v_plane, plane); + if(well_conditioned) + { + v_cov3d_ += -glm::outerProduct(Cinv_uvh, + cov3d_inv_u_Cinv_u * (uvh * (-tmp) + glm::transpose(nJ_inv_T) * v_plane)); + } + else + { + mat3 v_cov3d_inv = glm::outerProduct(uvh, + (-tmp * uvh + glm::transpose(nJ_inv_T) * v_plane) / u_Cinv_u_clmap); + vec3 v_eigenvector_min = (v_cov3d_inv + glm::transpose(v_cov3d_inv)) * eigenvector_min; + for(int j =0;j<3;j++) + { + if(j!=min_id) + { + T scale = glm::dot(cov3d_eigen_vector[j], v_eigenvector_min)/min(cov3d_eigen_value[min_id] - cov3d_eigen_value[j], - 0.0000001f); + v_cov3d_ += glm::outerProduct(cov3d_eigen_vector[j] * scale, eigenvector_min); + } + } + } + + vec3 v_uvh = cov3d_inv_u_Cinv_u * (2 * (-tmp) * uvh + glm::transpose(nJ_inv_T) * v_plane); + mat3x2 v_nJ_inv_T = glm::outerProduct(v_plane, Cinv_uvh_u_Cinv_u); + + // derivative of u v in factor, uvh and nJ_inv_T 
variables. + v_u = v_nl * 2 * u //dnl/du + + v_uvh.x + + (v_nJ_inv_T[0][1] + v_nJ_inv_T[1][0]) * (-v) + 2 * v_nJ_inv_T[1][1] * u - v_nJ_inv_T[2][0]; + v_v = v_nl * 2 * v //dnl/du + + v_uvh.y + + (v_nJ_inv_T[0][1] + v_nJ_inv_T[1][0]) * (-u) + 2 * v_nJ_inv_T[0][0] * v - v_nJ_inv_T[2][1]; + + } + // cov = J * V * Jt; G = df/dcov = v_cov // -> df/dV = Jt * G * J // -> df/dJ = G * J * Vt + Gt * J * V - v_cov3d += glm::transpose(J) * v_cov2d * J; + v_cov3d_ += glm::transpose(J) * v_cov2d * J; + + v_cov3d += v_cov3d_; // df/dx = fx * rz * df/dpixx // df/dy = fy * rz * df/dpixy @@ -217,18 +391,31 @@ inline __device__ void persp_proj_vjp( v_cov2d * J * glm::transpose(cov3d) + glm::transpose(v_cov2d) * J * cov3d; // fov clipping + T l3 = l * l * l; + T v_mean3d_x = -fx * rz2 * v_J[2][0] + v_u * rz + - v_nJ_T[0][2]*rz2 + v_nJ_T[2][0]*(1/l-tx*tx/l3) + (v_nJ_T[2][1] * tx + v_nJ_T[2][2] * z)*(-tx/l3) + + v_l * tx / l; + T v_mean3d_y = -fy * rz2 * v_J[2][1] + v_v * rz + - v_nJ_T[1][2]*rz2 + (v_nJ_T[2][0]* tx + v_nJ_T[2][2]* z) *(-ty/l3) + v_nJ_T[2][1]*(1/l-ty*ty/l3) + + v_l * ty / l; if (x * rz <= lim_x && x * rz >= -lim_x) { - v_mean3d.x += -fx * rz2 * v_J[2][0]; + v_mean3d.x += v_mean3d_x; } else { - v_mean3d.z += -fx * rz3 * v_J[2][0] * tx; + // v_mean3d.z += -fx * rz3 * v_J[2][0] * tx; + v_mean3d.z += v_mean3d_x * u; } if (y * rz <= lim_y && y * rz >= -lim_y) { - v_mean3d.y += -fy * rz2 * v_J[2][1]; + v_mean3d.y += v_mean3d_y; } else { - v_mean3d.z += -fy * rz3 * v_J[2][1] * ty; + // v_mean3d.z += -fy * rz3 * v_J[2][1] * ty; + v_mean3d.z += v_mean3d_y * v; } - v_mean3d.z += -fx * rz2 * v_J[0][0] - fy * rz2 * v_J[1][1] + - 2.f * fx * tx * rz3 * v_J[2][0] + 2.f * fy * ty * rz3 * v_J[2][1]; + v_mean3d.z += -fx * rz2 * v_J[0][0] - fy * rz2 * v_J[1][1] + + 2.f * fx * tx * rz3 * v_J[2][0] + 2.f * fy * ty * rz3 * v_J[2][1] + - (v_u * tx + v_v * ty) * rz2 + + v_nJ_T[0][0] * (-rz2) + v_nJ_T[1][1] * (-rz2) + v_nJ_T[0][2] * (2 * tx * rz3) + v_nJ_T[1][2] * (2 * ty * rz3) + + (v_nJ_T[2][0] * tx + v_nJ_T[2][1] * ty) * (-z/l3) + v_nJ_T[2][2] * (1 / l - z * z / l3) + + v_l * z / l; } template diff --git a/gsplat/rendering.py b/gsplat/rendering.py index da0d0ed74..b207e8307 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -36,6 +36,7 @@ def rasterization( sparse_grad: bool = False, absgrad: bool = False, rasterize_mode: Literal["classic", "antialiased"] = "classic", + require_rade: bool = False, channel_chunk: int = 32, ) -> Tuple[Tensor, Tensor, Dict]: """Rasterize a set of 3D Gaussians (N) to a batch of image planes (C). @@ -243,11 +244,14 @@ def rasterization( depths, conics, compensations, + ray_ts, + ray_planes, + normals ) = proj_results opacities = opacities[gaussian_ids] # [nnz] else: # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. 
diff --git a/gsplat/rendering.py b/gsplat/rendering.py
index da0d0ed74..b207e8307 100644
--- a/gsplat/rendering.py
+++ b/gsplat/rendering.py
@@ -36,6 +36,7 @@ def rasterization(
     sparse_grad: bool = False,
     absgrad: bool = False,
     rasterize_mode: Literal["classic", "antialiased"] = "classic",
+    require_rade: bool = False,
     channel_chunk: int = 32,
 ) -> Tuple[Tensor, Tensor, Dict]:
     """Rasterize a set of 3D Gaussians (N) to a batch of image planes (C).
@@ -243,11 +244,14 @@ def rasterization(
             depths,
             conics,
             compensations,
+            ray_ts,
+            ray_planes,
+            normals
         ) = proj_results
         opacities = opacities[gaussian_ids]  # [nnz]
     else:
         # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid.
-        radii, means2d, depths, conics, compensations = proj_results
+        radii, means2d, depths, conics, compensations, ray_ts, ray_planes, normals = proj_results
         opacities = opacities.repeat(C, 1)  # [C, N]
         camera_ids, gaussian_ids = None, None
 
@@ -338,39 +342,58 @@ def rasterization(
                 if backgrounds is not None
                 else None
             )
-            render_colors_, render_alphas_ = rasterize_to_pixels(
+            require_rade_ = require_rade and (not i)
+            render_colors_, render_alphas_, median_depths_, expected_depth_, expected_normal_ = rasterize_to_pixels(
                 means2d,
                 conics,
                 colors_chunk,
                 opacities,
+                ray_ts,
+                ray_planes,
+                normals,
                 width,
                 height,
                 tile_size,
                 isect_offsets,
                 flatten_ids,
+                Ks = Ks,
                 backgrounds=backgrounds_chunk,
                 packed=packed,
                 absgrad=absgrad,
+                require_geo=require_rade_
             )
             render_colors.append(render_colors_)
             render_alphas.append(render_alphas_)
+            if require_rade_:
+                expected_depth = expected_depth_/render_alphas_.clamp(min=1e-10)
+                median_depths = median_depths_
+                expected_normal = torch.nn.functional.normalize(expected_normal_,dim=-1)
+
         render_colors = torch.cat(render_colors, dim=-1)
         render_alphas = render_alphas[0]  # discard the rest
     else:
-        render_colors, render_alphas = rasterize_to_pixels(
+        render_colors, render_alphas, expected_depth, median_depths, expected_normal = rasterize_to_pixels(
             means2d,
             conics,
             colors,
             opacities,
+            ray_ts,
+            ray_planes,
+            normals,
             width,
             height,
             tile_size,
             isect_offsets,
             flatten_ids,
+            Ks = Ks,
             backgrounds=backgrounds,
             packed=packed,
             absgrad=absgrad,
+            require_geo=require_rade
         )
+        if require_rade:
+            expected_depth = expected_depth/render_alphas.clamp(min=1e-10)
+            expected_normal = torch.nn.functional.normalize(expected_normal,dim=-1)
     if render_mode in ["ED", "RGB+ED"]:
         # normalize the accumulated depth to get the expected depth
         render_colors = torch.cat(
@@ -400,7 +423,7 @@ def rasterization(
         "tile_size": tile_size,
         "n_cameras": C,
     }
-    return render_colors, render_alphas, meta
+    return render_colors, render_alphas, expected_depth, median_depths, expected_normal, meta
 
 
 def _rasterization(
diff --git a/train_dtu.sh b/train_dtu.sh
new file mode 100644
index 000000000..0f60fd498
--- /dev/null
+++ b/train_dtu.sh
@@ -0,0 +1,20 @@
+for i in 122 118 114 110 106 105 97 83 69 65 63 55 40 37 24
+do
+    python examples/simple_trainer_recon.py --eval_steps 100 2000 7000 15000 20000 25000 \
+        --disable_viewer --data_factor 2 \
+        --data_dir /media/super/data/dataset/dtu/DTU_mask/scan$i/ \
+        --result_dir output/scan$i/ \
+        --normal_consistency_loss \
+        --app_opt \
+        --test_every 1000000000 \
+        --absgrad
+    python examples/simple_trainer_recon.py --eval_steps 100 2000 7000 15000 20000 25000 \
+        --disable_viewer --data_factor 2 \
+        --data_dir /media/super/data/dataset/dtu/DTU_mask/scan$i/ \
+        --result_dir output/scan$i/ \
+        --normal_consistency_loss \
+        --app_opt \
+        --test_every 1000000000 \
+        --absgrad \
+        --ckpt output/scan$i/ckpts/ckpt_29999.pt
+done
\ No newline at end of file
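With the rendering.py change above, rasterization() returns six values instead of three once the caller opts in via require_rade. The sketch below is a usage illustration, not part of the patch: the inputs are random placeholders and the positional argument order is assumed to follow the usual gsplat convention (means, quats, scales, opacities, colors, viewmats, Ks, width, height). Two observations from the diff itself: the chunked color branch only assigns expected_depth / median_depths / expected_normal when require_rade is set, so callers that rely on the extra outputs should pass require_rade=True; and the chunked branch unpacks median_depths_ before expected_depth_ while the unchunked branch unpacks expected_depth before median_depths, so one of the two orders presumably needs to be aligned with rasterize_to_pixels.

import torch
import torch.nn.functional as F

from gsplat.rendering import rasterization

device = "cuda"
N, C, width, height = 1000, 1, 320, 240
means = torch.randn(N, 3, device=device)
means[:, 2] += 2.0  # push the Gaussians in front of the camera
quats = F.normalize(torch.randn(N, 4, device=device), dim=-1)
scales = torch.rand(N, 3, device=device) * 0.02
opacities = torch.rand(N, device=device)
colors = torch.rand(N, 3, device=device)
viewmats = torch.eye(4, device=device)[None].repeat(C, 1, 1)  # world-to-camera
Ks = torch.tensor(
    [[[300.0, 0.0, width / 2], [0.0, 300.0, height / 2], [0.0, 0.0, 1.0]]],
    device=device,
)

colors_out, alphas, expected_depth, median_depth, expected_normal, meta = rasterization(
    means, quats, scales, opacities, colors, viewmats, Ks, width, height,
    require_rade=True,  # the flag added by this patch
)
# The three new outputs are per-pixel maps alongside colors_out and alphas.

Because expected_depth is already divided by the accumulated alpha and expected_normal is re-normalized inside rasterization(), the outputs can be fed directly into a depth or normal loss.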
From eb150f8de699f0dc6bfe61b21a75feca3cac401c Mon Sep 17 00:00:00 2001
From: BaowenZ <947976219@qq.com>
Date: Thu, 8 Aug 2024 23:01:14 +0800
Subject: [PATCH 2/2] fix persp_proj_vjp

---
 gsplat/cuda/csrc/utils.cuh | 22 ++++++++++----------
 1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh
index f69a90173..cc39c8f42 100644
--- a/gsplat/cuda/csrc/utils.cuh
+++ b/gsplat/cuda/csrc/utils.cuh
@@ -318,7 +318,7 @@ inline __device__ void persp_proj_vjp(
         vec3 ray_normal_vector = {-plane.x*factor, -plane.y*factor, -1};
         vec3 cam_normal_vector = nJ_T * ray_normal_vector;
         vec3 normal = glm::normalize(cam_normal_vector);
-        vec2 ray_plane = {plane.x * factor / fx, plane.y * factor / fy};
+        // vec2 ray_plane = {plane.x * factor / fx, plane.y * factor / fy};
 
         T cam_normal_vector_length = glm::length(cam_normal_vector);
 
@@ -327,13 +327,13 @@ inline __device__ void persp_proj_vjp(
         vec3 v_ray_normal_vector = glm::transpose(nJ_T) * v_cam_normal_vector;
         v_nJ_T = glm::outerProduct(v_cam_normal_vector, ray_normal_vector);
 
-        // ray_plane_uv = {plane.x * factor, plane.y * factor};
+        vec2 ray_plane_uv = {plane.x * factor, plane.y * factor};
         const vec2 v_ray_plane_uv = {v_ray_plane.x / fx, v_ray_plane.y / fy};
         v_l = glm::dot(plane, -glm::make_vec2(v_ray_normal_vector) + v_ray_plane_uv) / nl;
         vec2 v_plane = {factor * (-v_ray_normal_vector.x + v_ray_plane_uv.x),
                         factor * (-v_ray_normal_vector.y + v_ray_plane_uv.y)};
         T v_nl = (-v_ray_normal_vector.x * ray_normal_vector.x - v_ray_normal_vector.y * ray_normal_vector.y
-                  -v_ray_plane.x*ray_plane.x - v_ray_plane.y*ray_plane.y) / nl;
+                  -v_ray_plane.x*ray_plane_uv.x - v_ray_plane.y*ray_plane_uv.y) / nl;
 
         T tmp = glm::dot(v_plane, plane);
         if(well_conditioned)
@@ -376,10 +376,11 @@ inline __device__ void persp_proj_vjp(
 
     v_cov3d += v_cov3d_;
 
+    vec3 v_mean3d_ = {0,0,0};
     // df/dx = fx * rz * df/dpixx
     // df/dy = fy * rz * df/dpixy
     // df/dz = - fx * mean.x * rz2 * df/dpixx - fy * mean.y * rz2 * df/dpixy
-    v_mean3d += vec3(fx * rz * v_mean2d[0], fy * rz * v_mean2d[1],
+    v_mean3d_ += vec3(fx * rz * v_mean2d[0], fy * rz * v_mean2d[1],
                      -(fx * x * v_mean2d[0] + fy * y * v_mean2d[1]) * rz2);
 
     // df/dx = -fx * rz2 * df/dJ_02
@@ -393,29 +394,30 @@ inline __device__ void persp_proj_vjp(
     // fov clipping
     T l3 = l * l * l;
     T v_mean3d_x = -fx * rz2 * v_J[2][0] + v_u * rz
-                   - v_nJ_T[0][2]*rz2 + v_nJ_T[2][0]*(1/l-tx*tx/l3) + (v_nJ_T[2][1] * tx + v_nJ_T[2][2] * z)*(-tx/l3)
+                   - v_nJ_T[0][2]*rz2 + v_nJ_T[2][0]*(1/l-tx*tx/l3) + (v_nJ_T[2][1] * ty + v_nJ_T[2][2] * z)*(-tx/l3)
                    + v_l * tx / l;
     T v_mean3d_y = -fy * rz2 * v_J[2][1] + v_v * rz
                    - v_nJ_T[1][2]*rz2 + (v_nJ_T[2][0]* tx + v_nJ_T[2][2]* z) *(-ty/l3) + v_nJ_T[2][1]*(1/l-ty*ty/l3)
                    + v_l * ty / l;
     if (x * rz <= lim_x && x * rz >= -lim_x) {
-        v_mean3d.x += v_mean3d_x;
+        v_mean3d_.x += v_mean3d_x;
     } else {
         // v_mean3d.z += -fx * rz3 * v_J[2][0] * tx;
-        v_mean3d.z += v_mean3d_x * u;
+        v_mean3d_.z += v_mean3d_x * u;
     }
     if (y * rz <= lim_y && y * rz >= -lim_y) {
-        v_mean3d.y += v_mean3d_y;
+        v_mean3d_.y += v_mean3d_y;
     } else {
         // v_mean3d.z += -fy * rz3 * v_J[2][1] * ty;
-        v_mean3d.z += v_mean3d_y * v;
+        v_mean3d_.z += v_mean3d_y * v;
     }
-    v_mean3d.z += -fx * rz2 * v_J[0][0] - fy * rz2 * v_J[1][1]
+    v_mean3d_.z += -fx * rz2 * v_J[0][0] - fy * rz2 * v_J[1][1]
                   + 2.f * fx * tx * rz3 * v_J[2][0] + 2.f * fy * ty * rz3 * v_J[2][1]
                   - (v_u * tx + v_v * ty) * rz2
                   + v_nJ_T[0][0] * (-rz2) + v_nJ_T[1][1] * (-rz2) + v_nJ_T[0][2] * (2 * tx * rz3) + v_nJ_T[1][2] * (2 * ty * rz3)
                   + (v_nJ_T[2][0] * tx + v_nJ_T[2][1] * ty) * (-z/l3) + v_nJ_T[2][2] * (1 / l - z * z / l3)
                   + v_l * z / l;
+    v_mean3d += v_mean3d_;
 }
 
 template
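train_dtu.sh passes --normal_consistency_loss to simple_trainer_recon.py; the actual loss lives in that trainer and is not shown here. The sketch below only illustrates one common way such a term can be formed from the new outputs: normals derived from the rendered expected depth are compared against the rendered normal map. The helper names and the sign convention of the cross product are assumptions and may need adjusting to the camera convention the trainer uses.

import torch
import torch.nn.functional as F


def depth_to_camera_points(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """depth: [H, W], K: [3, 3]. Back-projects every pixel to a camera-space point."""
    H, W = depth.shape
    fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    ys, xs = torch.meshgrid(
        torch.arange(H, device=depth.device, dtype=depth.dtype),
        torch.arange(W, device=depth.device, dtype=depth.dtype),
        indexing="ij",
    )
    x = (xs - cx) / fx * depth
    y = (ys - cy) / fy * depth
    return torch.stack([x, y, depth], dim=-1)  # [H, W, 3]


def normal_consistency(depth: torch.Tensor, normal: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """normal: [H, W, 3] camera-space normal map rendered alongside the depth."""
    pts = depth_to_camera_points(depth, K)
    # Finite differences along the image axes give two tangent vectors per pixel.
    dx = pts[:, 1:, :] - pts[:, :-1, :]   # [H, W-1, 3]
    dy = pts[1:, :, :] - pts[:-1, :, :]   # [H-1, W, 3]
    n = F.normalize(torch.cross(dx[:-1], dy[:, :-1], dim=-1), dim=-1)  # [H-1, W-1, 3]
    # Cosine distance between depth-derived and rendered normals; flip the cross
    # product if the camera convention orients normals the other way.
    return (1.0 - (n * normal[:-1, :-1]).sum(dim=-1)).mean()

Working on the [H-1, W-1] interior after finite differencing sidesteps border handling; masking by the rendered alpha before averaging is a natural extension for masked datasets such as DTU_mask.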