diff --git a/assets/camtoworlds.pt b/assets/camtoworlds.pt new file mode 100644 index 000000000..68a959b9b Binary files /dev/null and b/assets/camtoworlds.pt differ diff --git a/assets/means.pt b/assets/means.pt new file mode 100644 index 000000000..15645be5b Binary files /dev/null and b/assets/means.pt differ diff --git a/assets/quats.pt b/assets/quats.pt new file mode 100644 index 000000000..951f070ff Binary files /dev/null and b/assets/quats.pt differ diff --git a/assets/radii.pt b/assets/radii.pt new file mode 100644 index 000000000..fbc9a9af1 Binary files /dev/null and b/assets/radii.pt differ diff --git a/assets/scales.pt b/assets/scales.pt new file mode 100644 index 000000000..d2182ab2d Binary files /dev/null and b/assets/scales.pt differ diff --git a/examples/benchmark_mipnerf360.py b/examples/benchmark_mipnerf360.py new file mode 100644 index 000000000..62d39c5ea --- /dev/null +++ b/examples/benchmark_mipnerf360.py @@ -0,0 +1,93 @@ +# Training script for the Mip-NeRF 360 dataset + +import os +import GPUtil +from concurrent.futures import ThreadPoolExecutor +import time +import glob + +# 9 scenes +# scenes = ["bicycle", "bonsai", "counter", "flowers", "garden", "stump", "treehill", "kitchen", "room"] +# factors = [4, 2, 2, 4, 4, 4, 4, 2, 2] + +# 7 scenes +scenes = ["bicycle", "bonsai", "counter", "garden", "stump", "kitchen", "room"] +factors = [4, 2, 2, 4, 4, 2, 2] + +excluded_gpus = set([]) + +result_dir = "results/benchmark_mipsplatting_cuda3D" + +dry_run = False + +jobs = list(zip(scenes, factors)) + + +def train_scene(gpu, scene, factor): + # train without eval + cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer_mip_splatting.py --eval_steps -1 --disable_viewer --data_factor {factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene} --antialiased --kernel_size 0.1" + print(cmd) + if not dry_run: + os.system(cmd) + + # eval and render for all the ckpts + ckpts = glob.glob(f"{result_dir}/{scene}/ckpts/*.pt") + for ckpt in ckpts: + cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer_mip_splatting.py --disable_viewer --data_factor {factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene} --ckpt {ckpt} --antialiased --kernel_size 0.1" + print(cmd) + if not dry_run: + os.system(cmd) + + return True + + +def worker(gpu, scene, factor): + print(f"Starting job on GPU {gpu} with scene {scene}\n") + train_scene(gpu, scene, factor) + print(f"Finished job on GPU {gpu} with scene {scene}\n") + # This worker function starts a job and returns when it's done. + + +def dispatch_jobs(jobs, executor): + future_to_job = {} + reserved_gpus = set() # GPUs that are slated for work but may not be active yet + + while jobs or future_to_job: + # Get the list of available GPUs, not including those that are reserved. + all_available_gpus = set( + GPUtil.getAvailable(order="first", limit=10, maxMemory=0.1, maxLoad=0.1) + ) + # all_available_gpus = set([0,1,2,3]) + available_gpus = list(all_available_gpus - reserved_gpus - excluded_gpus) + + # Launch new jobs on available GPUs + while available_gpus and jobs: + gpu = available_gpus.pop(0) + job = jobs.pop(0) + future = executor.submit( + worker, gpu, *job + ) # Unpacking job as arguments to worker + future_to_job[future] = (gpu, job) + + reserved_gpus.add(gpu) # Reserve this GPU until the job starts processing + + # Check for completed jobs and remove them from the list of running jobs. + # Also, release the GPUs they were using. 
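+        # Note: `future.done()` is a non-blocking poll; a reserved GPU is only
+        # released here once its job's future has completed, so each GPU runs
+        # at most one scene at a time.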
+ done_futures = [future for future in future_to_job if future.done()] + for future in done_futures: + job = future_to_job.pop( + future + ) # Remove the job associated with the completed future + gpu = job[0] # The GPU is the first element in each job tuple + reserved_gpus.discard(gpu) # Release this GPU + print(f"Job {job} has finished., rellasing GPU {gpu}") + # (Optional) You might want to introduce a small delay here to prevent this loop from spinning very fast + # when there are no GPUs available. + time.sleep(5) + + print("All jobs have been processed.") + + +# Using ThreadPoolExecutor to manage the thread pool +with ThreadPoolExecutor(max_workers=8) as executor: + dispatch_jobs(jobs, executor) diff --git a/examples/benchmark_mipnerf360_stmt.py b/examples/benchmark_mipnerf360_stmt.py new file mode 100644 index 000000000..28952adad --- /dev/null +++ b/examples/benchmark_mipnerf360_stmt.py @@ -0,0 +1,115 @@ +# Training script for the Mip-NeRF 360 dataset +# The model is trained with downsampling factor 8 and rendered with downsampling factor 1, 2, 4, 8 + +import os +import GPUtil +from concurrent.futures import ThreadPoolExecutor +import time +import glob + +# 9 scenes +# scenes = ["bicycle", "bonsai", "counter", "flowers", "garden", "stump", "treehill", "kitchen", "room"] +# factors = [4, 2, 2, 4, 4, 4, 4, 2, 2] + +# 7 scenes +scenes = ["bicycle", "bonsai", "counter", "garden", "stump", "kitchen", "room"] +factors = [8] * len(scenes) + +excluded_gpus = set([]) + +# classic +result_dir = "results/benchmark_stmt" +# antialiased +result_dir = "results/benchmark_antialiased_stmt" +# mip-splatting +# result_dir = "results/benchmark_mipsplatting_stmt" + +dry_run = False + +jobs = list(zip(scenes, factors)) + + +def train_scene(gpu, scene, factor): + # train without eval + # classic + # cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer.py --eval_steps -1 --disable_viewer --data_factor {factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene}" + + # anti-aliased + # cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer.py --eval_steps -1 --disable_viewer --data_factor {factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene} --antialiased" + + # mip-splatting + cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer_mip_splatting.py --eval_steps -1 --disable_viewer --data_factor {factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene} --antialiased --kernel_size 0.1" + print(cmd) + if not dry_run: + os.system(cmd) + + # eval and render for all the ckpts + ckpts = glob.glob(f"{result_dir}/{scene}/ckpts/*.pt") + for ckpt in ckpts: + for test_factor in [1, 2, 4, 8]: + # classic + # cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer.py --disable_viewer --data_factor {test_factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene}_{test_factor} --ckpt {ckpt}" + + # anti-aliased + # cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer.py --disable_viewer --data_factor {test_factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene}_{test_factor} --ckpt {ckpt} --antialiased" + + # mip-splatting + cmd = f"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} python simple_trainer_mip_splatting.py --disable_viewer --data_factor {test_factor} --data_dir data/360_v2/{scene} --result_dir {result_dir}/{scene}_{test_factor} --ckpt {ckpt} --antialiased --kernel_size 0.1" + print(cmd) + if not dry_run: + os.system(cmd) + + return 
True + + +def worker(gpu, scene, factor): + print(f"Starting job on GPU {gpu} with scene {scene}\n") + train_scene(gpu, scene, factor) + print(f"Finished job on GPU {gpu} with scene {scene}\n") + # This worker function starts a job and returns when it's done. + + +def dispatch_jobs(jobs, executor): + future_to_job = {} + reserved_gpus = set() # GPUs that are slated for work but may not be active yet + + while jobs or future_to_job: + # Get the list of available GPUs, not including those that are reserved. + all_available_gpus = set( + GPUtil.getAvailable(order="first", limit=10, maxMemory=0.1, maxLoad=0.1) + ) + # all_available_gpus = set([0,1,2,3]) + available_gpus = list(all_available_gpus - reserved_gpus - excluded_gpus) + + # Launch new jobs on available GPUs + while available_gpus and jobs: + gpu = available_gpus.pop(0) + job = jobs.pop(0) + future = executor.submit( + worker, gpu, *job + ) # Unpacking job as arguments to worker + future_to_job[future] = (gpu, job) + + reserved_gpus.add(gpu) # Reserve this GPU until the job starts processing + time.sleep(2) + + # Check for completed jobs and remove them from the list of running jobs. + # Also, release the GPUs they were using. + done_futures = [future for future in future_to_job if future.done()] + for future in done_futures: + job = future_to_job.pop( + future + ) # Remove the job associated with the completed future + gpu = job[0] # The GPU is the first element in each job tuple + reserved_gpus.discard(gpu) # Release this GPU + print(f"Job {job} has finished., rellasing GPU {gpu}") + # (Optional) You might want to introduce a small delay here to prevent this loop from spinning very fast + # when there are no GPUs available. + time.sleep(5) + + print("All jobs have been processed.") + + +# Using ThreadPoolExecutor to manage the thread pool +with ThreadPoolExecutor(max_workers=8) as executor: + dispatch_jobs(jobs, executor) diff --git a/examples/datasets/colmap.py b/examples/datasets/colmap.py index 22eacc2ac..9019c82bb 100644 --- a/examples/datasets/colmap.py +++ b/examples/datasets/colmap.py @@ -182,6 +182,7 @@ def __init__( self.image_names = image_names # List[str], (num_images,) self.image_paths = image_paths # List[str], (num_images,) self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) + self.worldtocams = np.linalg.inv(camtoworlds) # np.ndarray, (num_images, 4, 4) self.camera_ids = camera_ids # List[int], (num_images,) self.Ks_dict = Ks_dict # Dict of camera_id -> K self.params_dict = params_dict # Dict of camera_id -> params @@ -254,6 +255,7 @@ def __getitem__(self, item: int) -> Dict[str, Any]: K = self.parser.Ks_dict[camera_id].copy() # undistorted K params = self.parser.params_dict[camera_id] camtoworlds = self.parser.camtoworlds[index] + worldtocams = self.parser.worldtocams[index] if len(params) > 0: # Images are distorted. Undistort them. 
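The `worldtocams` matrices added above are plain matrix inverses of the corresponding `camtoworlds`. A minimal sanity check of that relationship, assuming a constructed `Parser` instance named `parser` (hypothetical variable, not part of this diff):

```python
import numpy as np

c2w = parser.camtoworlds  # (num_images, 4, 4) camera-to-world transforms
w2c = parser.worldtocams  # (num_images, 4, 4) world-to-camera transforms

# Batched matmul: every pair should multiply to (approximately) the identity.
assert np.allclose(w2c @ c2w, np.tile(np.eye(4), (len(c2w), 1, 1)), atol=1e-5)
```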
@@ -277,6 +279,7 @@ def __getitem__(self, item: int) -> Dict[str, Any]: data = { "K": torch.from_numpy(K).float(), "camtoworld": torch.from_numpy(camtoworlds).float(), + "worldtocam": torch.from_numpy(worldtocams).float(), "image": torch.from_numpy(image).float(), "image_id": item, # the index of the image in the dataset } diff --git a/examples/datasets/download_dataset.py b/examples/datasets/download_dataset.py index 8366ae979..5cdb99ffd 100755 --- a/examples/datasets/download_dataset.py +++ b/examples/datasets/download_dataset.py @@ -9,9 +9,7 @@ import tyro # dataset names -dataset_names = Literal[ - "mipnerf360", -] +dataset_names = Literal["mipnerf360",] # dataset urls urls = {"mipnerf360": "http://storage.googleapis.com/gresearch/refraw360/360_v2.zip"} diff --git a/examples/show_mipnerf360.py b/examples/show_mipnerf360.py new file mode 100644 index 000000000..c89fa124d --- /dev/null +++ b/examples/show_mipnerf360.py @@ -0,0 +1,51 @@ +import json +import numpy as np +import glob + +# 9 scenes +# scenes = ['bicycle', 'flowers', 'garden', 'stump', 'treehill', 'room', 'counter', 'kitchen', 'bonsai'] + +# outdoor scenes +# scenes = scenes[:5] +# indoor scenes +# scenes = scenes[5:] + +# 7 scenes +scenes = ["bicycle", "bonsai", "counter", "garden", "stump", "kitchen", "room"] + +result_dirs = ["results/benchmark"] +result_dirs = ["results/benchmark_antialiased"] +result_dirs = ["results/benchmark_mipsplatting"] +result_dirs = ["results/benchmark_mipsplatting_cuda3D"] + +all_metrics = {"psnr": [], "ssim": [], "lpips": [], "num_GS": []} +print(result_dirs) + +for scene in scenes: + print(scene, end=" ") + for result_dir in result_dirs: + json_files = glob.glob(f"{result_dir}/{scene}/stats/val_step29999.json") + for json_file in json_files: + # print(json_file) + data = json.load(open(json_file)) + # print(data) + + for k in ["psnr", "ssim", "lpips", "num_GS"]: + all_metrics[k].append(data[k]) + print(f"{data[k]:.3f}", end=" ") + print() + +latex = [] +for k in ["psnr", "ssim", "lpips", "num_GS"]: + numbers = np.asarray(all_metrics[k]).mean(axis=0).tolist() + print(numbers) + numbers = [numbers] + if k == "PSNR": + numbers = [f"{x:.2f}" for x in numbers] + elif k == "num_GS": + num = numbers[0] / 1e6 + numbers = [f"{num:.2f}"] + else: + numbers = [f"{x:.3f}" for x in numbers] + latex.extend(numbers) +print(" | ".join(latex)) diff --git a/examples/show_mipnerf360_allscales.py b/examples/show_mipnerf360_allscales.py new file mode 100644 index 000000000..dcbf166ad --- /dev/null +++ b/examples/show_mipnerf360_allscales.py @@ -0,0 +1,45 @@ +import json +import numpy as np +import glob + +# 9 scenes +# scenes = ['bicycle', 'flowers', 'garden', 'stump', 'treehill', 'room', 'counter', 'kitchen', 'bonsai'] + +# outdoor scenes +# scenes = scenes[:5] +# indoor scenes +# scenes = scenes[5:] + +# 7 scenes +scenes = ["bicycle", "bonsai", "counter", "garden", "stump", "kitchen", "room"] + +result_dirs = ["results/benchmark_stmt"] +# result_dirs = ["results/benchmark_antialiased_stmt"] +# result_dirs = ["results/benchmark_mipsplatting_stmt"] + +all_metrics = {"psnr": [], "ssim": [], "lpips": [], "num_GS": []} +print(result_dirs) + +for scene in scenes: + print(scene) + for result_dir in result_dirs: + for scale in ["8", "4", "2", "1"]: + json_files = glob.glob(f"{result_dir}/{scene}_{scale}/stats/val_step29999.json") + for json_file in json_files: + data = json.load(open(json_file)) + for k in ["psnr", "ssim", "lpips", "num_GS"]: + all_metrics[k].append(data[k]) + print(f"{data[k]:.3f}", end=" ") + print() + 
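+# Each metric list now holds len(scenes) * 4 values in scene-major order
+# (scales 8, 4, 2, 1 per scene), so the reshape(-1, 4) below groups values by
+# scale before averaging across scenes; a mean over the four per-scale
+# averages is then appended as a final column.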
+latex = [] +for k in ["psnr", "ssim", "lpips"]: + numbers = np.asarray(all_metrics[k]).reshape(-1, 4).mean(axis=0).tolist() + numbers = numbers + [np.mean(numbers)] + print(numbers) + if k == "psnr": + numbers = [f"{x:.2f}" for x in numbers] + else: + numbers = [f"{x:.3f}" for x in numbers] + latex.extend(numbers) +print(" | ".join(latex)) diff --git a/examples/simple_trainer_gof.py b/examples/simple_trainer_gof.py new file mode 100644 index 000000000..50ed8a9e6 --- /dev/null +++ b/examples/simple_trainer_gof.py @@ -0,0 +1,1281 @@ +import json +import math +import os +import time +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import imageio +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import viser +import nerfview +from datasets.colmap import Dataset, Parser +from datasets.traj import generate_interpolated_path +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from utils import ( + AppearanceOptModule, + CameraOptModule, + knn, + normalized_quat_to_rotmat, + rgb_to_sh, + set_random_seed, + depth_to_normal, + create_tetrahedra_points, +) + +from gsplat.rendering import raytracing, integration +from gsplat import compute_3D_smoothing_filter +from tetranerf.utils.extension import cpp +from gsplat.tetmesh import marching_tetrahedra + +@dataclass +class Config: + # Disable viewer + disable_viewer: bool = False + # Path to the .pt file. If provide, it will skip training and render a video + ckpt: Optional[str] = None + + # Path to the Mip-NeRF 360 dataset + data_dir: str = "data/360_v2/garden" + # Downsample factor for the dataset + data_factor: int = 4 + # Directory to save results + result_dir: str = "results/garden" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 30_000 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + + # Initialization strategy + init_type: str = "sfm" + # Initial number of GSs. Ignored if using sfm + init_num_pts: int = 100_000 + # Initial extent of GSs as a multiple of the camera extent. 
Ignored if using sfm + init_extent: float = 3.0 + # Degree of spherical harmonics + sh_degree: int = 3 + # Turn on another SH degree every this steps + sh_degree_interval: int = 1000 + # Initial opacity of GS + init_opa: float = 0.1 + # Initial scale of GS + init_scale: float = 1.0 + # Weight for SSIM loss + ssim_lambda: float = 0.2 + + # Near plane clipping distance + near_plane: float = 0.01 + # Far plane clipping distance + far_plane: float = 1e10 + + # GSs with opacity below this value will be pruned + prune_opa: float = 0.005 + # GSs with image plane gradient above this value will be split/duplicated + grow_grad2d: float = 0.0002 + # GSs with scale below this value will be duplicated. Above will be split + grow_scale3d: float = 0.01 + # GSs with scale above this value will be pruned. + prune_scale3d: float = 0.1 + + # Start refining GSs after this iteration + refine_start_iter: int = 500 + # Stop refining GSs after this iteration + refine_stop_iter: int = 15_000 + # Reset opacities every this steps + reset_every: int = 3000 + # Refine GSs every this steps + refine_every: int = 100 + + # Use packed mode for rasterization, this leads to less memory usage but slightly slower. + packed: bool = False + # Use sparse gradients for optimization. (experimental) + sparse_grad: bool = False + # Use absolute gradient for pruning. This typically requires larger --grow_grad2d, e.g., 0.0008 or 0.0006 + absgrad: bool = False + # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. + antialiased: bool = False + # kernel size for the low-pass filter in rasterization. 0.1 should be a better value since it better approximates a 2D box filter of single pixel size. + kernel_size: float = 0.3 + + # Use random background for training to discourage transparency + random_bkgd: bool = False + + # Enable camera optimization. + pose_opt: bool = False + # Learning rate for camera optimization + pose_opt_lr: float = 1e-5 + # Regularization for camera optimization as weight decay + pose_opt_reg: float = 1e-6 + # Add noise to camera extrinsics. This is only to test the camera pose optimization. + pose_noise: float = 0.0 + + # Enable appearance optimization. (experimental) + app_opt: bool = False + # Appearance embedding dimension + app_embed_dim: int = 16 + # Learning rate for appearance optimization + app_opt_lr: float = 1e-3 + # Regularization for appearance optimization as weight decay + app_opt_reg: float = 1e-6 + + # Enable depth loss. 
(experimental) + depth_loss: bool = False + # Weight for depth loss + depth_lambda: float = 1e-2 + + # Dump information to tensorboard every this steps + tb_every: int = 100 + # Save training images to tensorboard + tb_save_image: bool = False + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.max_steps = int(self.max_steps * factor) + self.sh_degree_interval = int(self.sh_degree_interval * factor) + self.refine_start_iter = int(self.refine_start_iter * factor) + self.refine_stop_iter = int(self.refine_stop_iter * factor) + self.reset_every = int(self.reset_every * factor) + self.refine_every = int(self.refine_every * factor) + + +def create_splats_with_optimizers( + parser: Parser, + init_type: str = "sfm", + init_num_pts: int = 100_000, + init_extent: float = 3.0, + init_opacity: float = 0.1, + init_scale: float = 1.0, + scene_scale: float = 1.0, + sh_degree: int = 3, + sparse_grad: bool = False, + batch_size: int = 1, + feature_dim: Optional[int] = None, + device: str = "cuda", +) -> Tuple[torch.nn.ParameterDict, torch.optim.Optimizer]: + if init_type == "sfm": + points = torch.from_numpy(parser.points).float() + rgbs = torch.from_numpy(parser.points_rgb / 255.0).float() + elif init_type == "random": + points = init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1) + rgbs = torch.rand((init_num_pts, 3)) + else: + raise ValueError("Please specify a correct init_type: sfm or random") + + N = points.shape[0] + # Initialize the GS size to be the average dist of the 3 nearest neighbors + dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] + quats = torch.rand((N, 4)) # [N, 4] + opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] + + params = [ + # name, value, lr + ("means3d", torch.nn.Parameter(points), 1.6e-4 * scene_scale), + ("scales", torch.nn.Parameter(scales), 5e-3), + ("quats", torch.nn.Parameter(quats), 1e-3), + ("opacities", torch.nn.Parameter(opacities), 5e-2), + # 3D smoothing filter, setting lr to 0.0 to disable optimization + ("filters", torch.nn.Parameter(torch.ones_like(opacities)), 0.0), + ] + + if feature_dim is None: + # color is SH coefficients. 
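+        # Each Gaussian stores (sh_degree + 1) ** 2 coefficients per color channel;
+        # only the DC term (sh0) is initialized from the point color via rgb_to_sh,
+        # while the higher-order terms (shN) start at zero.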
+ colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] + colors[:, 0, :] = rgb_to_sh(rgbs) + params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3)) + params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20)) + else: + # features will be used for appearance and view-dependent shading + features = torch.rand(N, feature_dim) # [N, feature_dim] + params.append(("features", torch.nn.Parameter(features), 2.5e-3)) + colors = torch.logit(rgbs) # [N, 3] + params.append(("colors", torch.nn.Parameter(colors), 2.5e-3)) + + splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) + # Scale learning rate based on batch size, reference: + # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + # Note that this would not make the training exactly equivalent, see + # https://arxiv.org/pdf/2402.18824v1 + optimizers = [ + (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + [{"params": splats[name], "lr": lr * math.sqrt(batch_size), "name": name}], + eps=1e-15 / math.sqrt(batch_size), + betas=(1 - batch_size * (1 - 0.9), 1 - batch_size * (1 - 0.999)), + ) + for name, _, lr in params + ] + return splats, optimizers + + +class Runner: + """Engine for training and testing.""" + + def __init__(self, cfg: Config) -> None: + set_random_seed(42) + + self.cfg = cfg + self.device = "cuda" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}/ckpts" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # Load data: Training data should contain initial points and colors. + self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=True, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + print("Scene scale:", self.scene_scale) + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers = create_splats_with_optimizers( + self.parser, + init_type=cfg.init_type, + init_num_pts=cfg.init_num_pts, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + scene_scale=self.scene_scale, + sh_degree=cfg.sh_degree, + sparse_grad=cfg.sparse_grad, + batch_size=cfg.batch_size, + feature_dim=feature_dim, + device=self.device, + ) + print("Model initialized. 
Number of GS:", len(self.splats["means3d"])) + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + + self.app_optimizers = [] + if cfg.app_opt: + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. + torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + + # Losses & Metrics. + self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to( + self.device + ) + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) + + # Running stats for prunning & growing. + n_gauss = len(self.splats["means3d"]) + self.running_stats = { + "grad2d": torch.zeros(n_gauss, device=self.device), # norm of the gradient + "count": torch.zeros(n_gauss, device=self.device, dtype=torch.int), + } + + def get_scale_opacity_with_smoothing_filer(self): + scales = torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + filters = self.splats["filters"] # [N,] + + # apply 3D smoothing filter to scales and opacities + scales_square = torch.square(scales) # [N, 3] + det1 = scales_square.prod(dim=1) # [N, ] + + scales_after_square = scales_square + torch.square(filters)[:, None] # [N, 1] + det2 = scales_after_square.prod(dim=1) # [N,] + coef = torch.sqrt(det1 / det2 + 1e-7) # [N,] + opacities = opacities * coef + + scales = torch.square(scales) + torch.square(filters)[:, None] # [N, 3] + scales = torch.sqrt(scales) + return scales, opacities + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict]: + means = self.splats["means3d"] # [N, 3] + # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] + # rasterization does normalization internally + quats = self.splats["quats"] # [N, 4] + + scales, opacities = self.get_scale_opacity_with_smoothing_filer() + + image_ids = kwargs.pop("image_ids", None) + if self.cfg.app_opt: + colors = self.app_module( + features=self.splats["features"], + embed_ids=image_ids, + dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], + sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + ) + colors = colors + self.splats["colors"] + colors = torch.sigmoid(colors) + else: + colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] + + rasterize_mode = "antialiased" if self.cfg.antialiased else 
"classic" + render_colors, render_alphas, info = raytracing( + means=means, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=self.cfg.absgrad, + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + eps2d=self.cfg.kernel_size, + **kwargs, + ) + return render_colors, render_alphas, info + + def train(self): + cfg = self.cfg + device = self.device + + # Dump cfg. + with open(f"{cfg.result_dir}/cfg.json", "w") as f: + json.dump(vars(cfg), f) + + max_steps = cfg.max_steps + init_step = 0 + + schedulers = [ + # means3d has a learning rate schedule, that end at 0.01 of the initial value + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ), + ] + if cfg.pose_opt: + # pose optimization has a learning rate schedule + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ) + ) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=4, + persistent_workers=True, + pin_memory=True, + ) + trainloader_iter = iter(trainloader) + + # determine the 3D smoothing filter before training + self.compute_3D_smoothing_filter() + + # Training loop. + global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + for step in pbar: + if not cfg.disable_viewer: + while self.viewer.state.status == "paused": + time.sleep(0.01) + self.viewer.lock.acquire() + tic = time.time() + + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + num_train_rays_per_step = ( + pixels.shape[0] * pixels.shape[1] * pixels.shape[2] + ) + image_ids = data["image_id"].to(device) + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + if cfg.pose_noise: + camtoworlds = self.pose_perturb(camtoworlds, image_ids) + + if cfg.pose_opt: + camtoworlds = self.pose_adjust(camtoworlds, image_ids) + + # sh schedule + sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree) + + # forward + renders, alphas, info = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=sh_degree_to_use, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else "RGB", + ) + + if renders.shape[-1] == 7: + colors, normals, depths = ( + renders[..., 0:3], + renders[..., 3:6], + renders[..., 6:7], + ) + else: + colors, depths = renders, None + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + colors = colors + bkgd * (1.0 - alphas) + + info["means2d"].retain_grad() # used for running stats + + # loss + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - self.ssim( + pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2) + ) + loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda + if cfg.depth_loss: + # query depths from depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = 
points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate loss in disparity space + disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale + loss += depthloss * cfg.depth_lambda + + # depth normal consistency loss + normals = F.normalize(normals, dim=-1) + depth_normals = depth_to_normal(Ks[0], depths[0, ..., 0]) + + normal_error = 1 - (normals[0] * depth_normals).sum(dim=-1) + depth_normal_loss = normal_error.mean() + if step > 15000: + loss += depth_normal_loss * 0.05 + + loss.backward() + + desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + if cfg.pose_opt and cfg.pose_noise: + # monitor the pose error if we inject noise + pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) + desc += f"pose err={pose_err.item():.6f}| " + pbar.set_description(desc) + + # save image to debug + if step % 100 == 0: + normals_vis = normals * 0.5 + 0.5 + depth_normals_vis = depth_normals * 0.5 + 0.5 + depths_vis = (depths - depths.min()) / ( + depths.max() - depths.min() + ).repeat(1, 1, 1, 3) + # breakpoint() + colors = torch.clamp(colors, 0, 1) + canvas = ( + torch.cat( + [ + pixels, + colors, + depths_vis, + normals_vis, + depth_normals_vis[None], + ], + dim=-2, + ) + .detach() + .cpu() + .numpy() + ) + canvas = canvas.reshape(-1, *canvas.shape[2:]) + imageio.imwrite( + f"{cfg.result_dir}/train_{step:04d}.png", + (canvas * 255).astype(np.uint8), + ) + + if cfg.tb_every > 0 and step % cfg.tb_every == 0: + mem = torch.cuda.max_memory_allocated() / 1024**3 + self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar( + "train/num_GS", len(self.splats["means3d"]), step + ) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.tb_save_image: + canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + # update running stats for prunning & growing + if step < cfg.refine_stop_iter: + self.update_running_stats(info) + + if step > cfg.refine_start_iter and step % cfg.refine_every == 0: + grads = self.running_stats["grad2d"] / self.running_stats[ + "count" + ].clamp_min(1) + + # grow GSs + is_grad_high = grads >= cfg.grow_grad2d + is_small = ( + torch.exp(self.splats["scales"]).max(dim=-1).values + <= cfg.grow_scale3d * self.scene_scale + ) + is_dupli = is_grad_high & is_small + n_dupli = is_dupli.sum().item() + self.refine_duplicate(is_dupli) + + is_split = is_grad_high & ~is_small + is_split = torch.cat( + [ + is_split, + # new GSs added by duplication will not be split + torch.zeros(n_dupli, device=device, dtype=torch.bool), + ] + ) + n_split = is_split.sum().item() + self.refine_split(is_split) + print( + f"Step {step}: {n_dupli} GSs duplicated, {n_split} GSs split. " + f"Now having {len(self.splats['means3d'])} GSs." + ) + + # prune GSs + is_prune = torch.sigmoid(self.splats["opacities"]) < (cfg.prune_opa * 10.) 
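+                # Once opacity resets have started (step > reset_every), also prune
+                # Gaussians whose largest scale exceeds prune_scale3d * scene_scale.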
+ if step > cfg.reset_every: + # The official code also implements sreen-size pruning but + # it's actually not being used due to a bug: + # https://github.com/graphdeco-inria/gaussian-splatting/issues/123 + is_too_big = ( + torch.exp(self.splats["scales"]).max(dim=-1).values + > cfg.prune_scale3d * self.scene_scale + ) + is_prune = is_prune | is_too_big + n_prune = is_prune.sum().item() + self.refine_keep(~is_prune) + print( + f"Step {step}: {n_prune} GSs pruned. " + f"Now having {len(self.splats['means3d'])} GSs." + ) + self.compute_3D_smoothing_filter() + + # reset running stats + self.running_stats["grad2d"].zero_() + self.running_stats["count"].zero_() + + if step % cfg.reset_every == 0: + self.reset_opa(cfg.prune_opa * 2.0) + + # Turn Gradients into Sparse Tensor before running optimizer + if cfg.sparse_grad: + assert cfg.packed, "Sparse gradients only work with packed mode." + gaussian_ids = info["gaussian_ids"] + for k in self.splats.keys(): + grad = self.splats[k].grad + if grad is None or grad.is_sparse: + continue + self.splats[k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] + size=self.splats[k].size(), # [N, ...] + is_coalesced=len(Ks) == 1, + ) + + # optimize + for optimizer in self.optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.pose_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.app_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # save checkpoint + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["means3d"]), + } + print("Step: ", step, stats) + with open(f"{self.stats_dir}/train_step{step:04d}.json", "w") as f: + json.dump(stats, f) + torch.save( + { + "step": step, + "splats": self.splats.state_dict(), + }, + f"{self.ckpt_dir}/ckpt_{step}.pt", + ) + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1: + self.eval(step) + self.render_traj(step) + + if not cfg.disable_viewer: + self.viewer.lock.release() + num_train_steps_per_sec = 1.0 / (time.time() - tic) + num_train_rays_per_sec = ( + num_train_rays_per_step * num_train_steps_per_sec + ) + # Update the viewer state. + self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec + # Update the scene. 
+ self.viewer.update(step, num_train_rays_per_step) + + @torch.no_grad() + def compute_3D_smoothing_filter(self): + cfg = self.cfg + device = self.device + xyz = self.splats["means3d"] + print("xyz", xyz.shape, xyz.device) + + worldtocams = ( + torch.from_numpy(self.trainset.parser.worldtocams).float().to(device) + ) + # TODO, currently use K, H, W of the first image + data = self.trainset[0] + K = data["K"].to(device) # [3, 3] + height, width = data["image"].shape[:2] + Ks = torch.stack([K] * len(worldtocams), dim=0) # [C, 3, 3] + filter_3D = compute_3D_smoothing_filter( + xyz, worldtocams, Ks, width, height, cfg.near_plane + ) + valid_points = filter_3D < 1000000.0 + filter_3D[~valid_points] = filter_3D[valid_points].max() + # 0.3 since we don't use anti-aliasing for gof at the moment + filter_3D = filter_3D * (0.3**0.5) + self.splats["filters"] = torch.nn.Parameter(filter_3D) + + @torch.no_grad() + def compute_3D_smoothing_filter_torch(self): + print("Computing 3D filter") + cfg = self.cfg + device = self.device + xyz = self.splats["means3d"] + print("xyz", xyz.shape, xyz.device) + + distance = torch.ones((xyz.shape[0]), device=xyz.device) * 100000.0 + valid_points = torch.zeros((xyz.shape[0]), device=xyz.device, dtype=torch.bool) + focal_length = 0.0 + + for data in self.trainset: + worldtocam = data["worldtocam"].to(device) # [4, 4] + K = data["K"].to(device) # [3, 3] + height, width = data["image"].shape[:2] + R = worldtocam[:3, :3] + T = worldtocam[:3, 3] + + xyz_cam = xyz @ R.transpose(1, 0) + T[None, :] + + # project to screen space + valid_depth = xyz_cam[:, 2] > cfg.near_plane + + x, y, z = xyz_cam[:, 0], xyz_cam[:, 1], xyz_cam[:, 2] + z = torch.clamp(z, min=0.001) + + x = x / z * K[0, 0] + K[0, 2] + y = y / z * K[1, 1] + K[1, 2] + + # use similar tangent space filtering as in 3DGS, + # TODO check gsplat's implementation + in_screen = torch.logical_and( + torch.logical_and(x >= -0.15 * width, x <= width * 1.15), + torch.logical_and(y >= -0.15 * height, y <= 1.15 * height), + ) + valid = torch.logical_and(valid_depth, in_screen) + + distance[valid] = torch.min(distance[valid], z[valid]) + valid_points = torch.logical_or(valid_points, valid) + if focal_length < K[0, 0]: + focal_length = K[0, 0] + + distance[~valid_points] = distance[valid_points].max() + + filter_3D = distance / focal_length * (0.2**0.5) + self.splats["filters"] = torch.nn.Parameter(filter_3D) + + @torch.no_grad() + def update_running_stats(self, info: Dict): + """Update running stats.""" + cfg = self.cfg + + # normalize grads to [-1, 1] screen space + if cfg.absgrad: + grads = info["means2d"].absgrad.clone() + else: + grads = info["means2d"].grad.clone() + grads[..., 0] *= info["width"] / 2.0 * cfg.batch_size + grads[..., 1] *= info["height"] / 2.0 * cfg.batch_size + if cfg.packed: + # grads is [nnz, 2] + gs_ids = info["gaussian_ids"] # [nnz] or None + self.running_stats["grad2d"].index_add_(0, gs_ids, grads.norm(dim=-1)) + self.running_stats["count"].index_add_( + 0, gs_ids, torch.ones_like(gs_ids).int() + ) + else: + # grads is [C, N, 2] + sel = info["radii"] > 0.0 # [C, N] + gs_ids = torch.where(sel)[1] # [nnz] + self.running_stats["grad2d"].index_add_(0, gs_ids, grads[sel].norm(dim=-1)) + self.running_stats["count"].index_add_( + 0, gs_ids, torch.ones_like(gs_ids).int() + ) + + @torch.no_grad() + def reset_opa(self, value: float = 0.01): + """Utility function to reset opacities.""" + # opacities = torch.clamp( + # self.splats["opacities"], max=torch.logit(torch.tensor(value)).item() + # ) + + scales = 
torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + filters = self.splats["filters"] # [N,] + + # apply 3D smoothing filter to scales and opacities + print("apply 3D smoothing filter in reset opacities") + scales_square = torch.square(scales) # [N, 3] + det1 = scales_square.prod(dim=1) # [N, ] + + scales_after_square = scales_square + torch.square(filters)[:, None] # [N, 1] + det2 = scales_after_square.prod(dim=1) # [N,] + coef = torch.sqrt(det1 / det2 + 1e-7) # [N,] + opacities = opacities * coef + + opacities_reset = torch.min(opacities, torch.ones_like(opacities) * value) + opacities_reset = opacities_reset / (coef + 1e-7) + opacities = torch.logit(opacities_reset) + + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + if param_group["name"] != "opacities": + continue + p = param_group["params"][0] + p_state = optimizer.state[p] + del optimizer.state[p] + for key in p_state.keys(): + if key != "step": + p_state[key] = torch.zeros_like(p_state[key]) + p_new = torch.nn.Parameter(opacities) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[param_group["name"]] = p_new + torch.cuda.empty_cache() + + @torch.no_grad() + def refine_split(self, mask: Tensor): + """Utility function to grow GSs.""" + device = self.device + + sel = torch.where(mask)[0] + rest = torch.where(~mask)[0] + + scales = torch.exp(self.splats["scales"][sel]) # [N, 3] + quats = F.normalize(self.splats["quats"][sel], dim=-1) # [N, 4] + rotmats = normalized_quat_to_rotmat(quats) # [N, 3, 3] + samples = torch.einsum( + "nij,nj,bnj->bni", + rotmats, + scales, + torch.randn(2, len(scales), 3, device=device), + ) # [2, N, 3] + + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + p = param_group["params"][0] + name = param_group["name"] + # create new params + if name == "means3d": + p_split = (p[sel] + samples).reshape(-1, 3) # [2N, 3] + elif name == "scales": + p_split = torch.log(scales / 1.6).repeat(2, 1) # [2N, 3] + else: + repeats = [2] + [1] * (p.dim() - 1) + p_split = p[sel].repeat(repeats) + p_new = torch.cat([p[rest], p_split]) + p_new = torch.nn.Parameter(p_new) + # update optimizer + p_state = optimizer.state[p] + del optimizer.state[p] + for key in p_state.keys(): + if key == "step": + continue + v = p_state[key] + # new params are assigned with zero optimizer states + # (worth investigating it) + v_split = torch.zeros((2 * len(sel), *v.shape[1:]), device=device) + p_state[key] = torch.cat([v[rest], v_split]) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[name] = p_new + for k, v in self.running_stats.items(): + if v is None: + continue + repeats = [2] + [1] * (v.dim() - 1) + v_new = v[sel].repeat(repeats) + self.running_stats[k] = torch.cat((v[rest], v_new)) + torch.cuda.empty_cache() + + @torch.no_grad() + def refine_duplicate(self, mask: Tensor): + """Unility function to duplicate GSs.""" + sel = torch.where(mask)[0] + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + p = param_group["params"][0] + name = param_group["name"] + p_state = optimizer.state[p] + del optimizer.state[p] + for key in p_state.keys(): + if key != "step": + # new params are assigned with zero optimizer states + # (worth investigating it as it will lead to a lot more GS.) 
+ v = p_state[key] + v_new = torch.zeros( + (len(sel), *v.shape[1:]), device=self.device + ) + # v_new = v[sel] + p_state[key] = torch.cat([v, v_new]) + p_new = torch.nn.Parameter(torch.cat([p, p[sel]])) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[name] = p_new + for k, v in self.running_stats.items(): + self.running_stats[k] = torch.cat((v, v[sel])) + torch.cuda.empty_cache() + + @torch.no_grad() + def refine_keep(self, mask: Tensor): + """Unility function to prune GSs.""" + sel = torch.where(mask)[0] + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + p = param_group["params"][0] + name = param_group["name"] + p_state = optimizer.state[p] + del optimizer.state[p] + for key in p_state.keys(): + if key != "step": + p_state[key] = p_state[key][sel] + p_new = torch.nn.Parameter(p[sel]) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[name] = p_new + for k, v in self.running_stats.items(): + self.running_stats[k] = v[sel] + torch.cuda.empty_cache() + + @torch.no_grad() + def eval(self, step: int): + """Entry for evaluation.""" + print("Running evaluation...") + cfg = self.cfg + device = self.device + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = {"psnr": [], "ssim": [], "lpips": []} + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + renders, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + ) # [1, H, W, 3] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [N, H, W, 3] + # normals = renders[..., 3:6] # [N, H, W, 3] + # depths = renders[..., 6:7] # [N, H, W, 1] + # depths = (depths - depths.min()) / (depths.max() - depths.min()) + + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + # write images + canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() + imageio.imwrite( + f"{self.render_dir}/val_{i:04d}.png", (canvas * 255).astype(np.uint8) + ) + + pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors, pixels)) + metrics["ssim"].append(self.ssim(colors, pixels)) + metrics["lpips"].append(self.lpips(colors, pixels)) + + ellipse_time /= len(valloader) + + psnr = torch.stack(metrics["psnr"]).mean() + ssim = torch.stack(metrics["ssim"]).mean() + lpips = torch.stack(metrics["lpips"]).mean() + print( + f"PSNR: {psnr.item():.3f}, SSIM: {ssim.item():.4f}, LPIPS: {lpips.item():.3f} " + f"Time: {ellipse_time:.3f}s/image " + f"Number of GS: {len(self.splats['means3d'])}" + ) + # save stats as json + stats = { + "psnr": psnr.item(), + "ssim": ssim.item(), + "lpips": lpips.item(), + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["means3d"]), + } + with open(f"{self.stats_dir}/val_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"val/{k}", v, step) + self.writer.flush() + + @torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device 
= self.device + + camtoworlds = self.parser.camtoworlds[5:-5] + camtoworlds = generate_interpolated_path(camtoworlds, 1) # [N, 3, 4] + camtoworlds = np.concatenate( + [ + camtoworlds, + np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds = torch.from_numpy(camtoworlds).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + canvas_all = [] + for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"): + renders, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds[i : i + 1], + Ks=K[None], + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3] + normals = renders[0, ..., 3:6] # [H, W, 3] + depths = renders[0, ..., 6:7] # [H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + + normals = F.normalize(normals, dim=-1) + normals = (normals + 1.0) / 2.0 + # write images + canvas = torch.cat( + [colors, normals, depths.repeat(1, 1, 3)], + dim=0 if width > height else 1, + ) + canvas = (canvas.cpu().numpy() * 255).astype(np.uint8) + canvas_all.append(canvas) + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + for canvas in canvas_all: + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + renders, _, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=W, + height=H, + sh_degree=self.cfg.sh_degree, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) + return colors[0].cpu().numpy() + + @torch.no_grad() + def evaluate_alpha(self, points, **kwargs): + device = self.device + means = self.splats["means3d"] # [N, 3] + quats = self.splats["quats"] # [N, 4] + scales, opacities = self.get_scale_opacity_with_smoothing_filer() + colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] + + trainloader = torch.utils.data.DataLoader( + self.trainset, batch_size=1, shuffle=False, num_workers=1 + ) + + final_alpha = torch.ones((points.shape[0]), dtype=torch.float32, device=device) + + for i, data in enumerate(trainloader): + print(i) + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + height, width = pixels.shape[1:3] + + render_colors, render_alphas, info = integration( + points=points, + means=means, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + sh_degree=self.cfg.sh_degree, + **kwargs, + ) + # out_colors = torch.clamp(render_colors[..., :3], 0.0, 1.0) + # # write images + # canvas = torch.cat([pixels, out_colors], dim=2).squeeze(0).cpu().numpy() + # imageio.imwrite( + # 
f"{self.render_dir}/train_{i:04d}.png", (canvas * 255).astype(np.uint8) + # ) + # breakpoint() + assert render_alphas.shape[0] == 1 + final_alpha = torch.min(final_alpha, render_alphas.reshape(-1)) + print(final_alpha.mean()) + # break + alpha = 1 - final_alpha + return alpha + + @torch.no_grad() + def extract_mesh(self, **kwargs): + import trimesh + device = self.device + means = self.splats["means3d"] # [N, 3] + quats = self.splats["quats"] # [N, 4] + scales, opacities = self.get_scale_opacity_with_smoothing_filer() + points = create_tetrahedra_points(quats, means, scales) + print(points.shape, means.shape) + print("create cells and save") + cells = cpp.triangulate(points) + # print("save cells", cells.shape) + torch.save(cells, f"{cfg.result_dir}/cells.pt") + # load cells + print("load cells") + cells = torch.load(f"{cfg.result_dir}/cells.pt") + tets = cells.cuda().long() + + alpha = self.evaluate_alpha(points) + + print(points.shape, tets.shape, alpha.shape) + def alpha_to_sdf(alpha): + sdf = alpha - 0.5 + sdf = sdf[None] + return sdf + + sdf = alpha_to_sdf(alpha) + + torch.cuda.empty_cache() + # breakpoint() + # temp for points_scale + points_scale = sdf[0] + verts_list, scale_list, faces_list, _ = marching_tetrahedra(points[None], tets, sdf, points_scale[None]) + torch.cuda.empty_cache() + + end_points, end_sdf = verts_list[0] + end_scales = scale_list[0] + + faces=faces_list[0].cpu().numpy() + points = (end_points[:, 0, :] + end_points[:, 1, :]) / 2. + + left_points = end_points[:, 0, :] + right_points = end_points[:, 1, :] + left_sdf = end_sdf[:, 0, :] + right_sdf = end_sdf[:, 1, :] + left_scale = end_scales[:, 0, 0] + right_scale = end_scales[:, 1, 0] + distance = torch.norm(left_points - right_points, dim=-1) + scale = left_scale + right_scale + + n_binary_steps = 8 + for step in range(n_binary_steps): + print("binary search in step {}".format(step)) + mid_points = (left_points + right_points) / 2 + alpha = self.evaluate_alpha(mid_points) + mid_sdf = alpha_to_sdf(alpha).squeeze().unsqueeze(-1) + + ind_low = ((mid_sdf < 0) & (left_sdf < 0)) | ((mid_sdf > 0) & (left_sdf > 0)) + + left_sdf[ind_low] = mid_sdf[ind_low] + right_sdf[~ind_low] = mid_sdf[~ind_low] + left_points[ind_low.flatten()] = mid_points[ind_low.flatten()] + right_points[~ind_low.flatten()] = mid_points[~ind_low.flatten()] + + points = (left_points + right_points) / 2 + if step not in [7]: + continue + + vertex_colors=None + mesh = trimesh.Trimesh(vertices=points.cpu().numpy(), faces=faces, vertex_colors=vertex_colors, process=False) + + mesh.export(f"{cfg.result_dir}/{step}.ply") + + +def main(cfg: Config): + runner = Runner(cfg) + + if cfg.ckpt is not None: + # run eval only + ckpt = torch.load(cfg.ckpt, map_location=runner.device) + for k in runner.splats.keys(): + runner.splats[k].data = ckpt["splats"][k] + # runner.eval(step=ckpt["step"]) + # runner.render_traj(step=ckpt["step"]) + runner.extract_mesh() + else: + runner.train() + + if not cfg.disable_viewer: + print("Viewer running... 
Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + cfg = tyro.cli(Config) + cfg.adjust_steps(cfg.steps_scaler) + main(cfg) diff --git a/examples/simple_trainer_mip_splatting.py b/examples/simple_trainer_mip_splatting.py new file mode 100644 index 000000000..57a75ce9a --- /dev/null +++ b/examples/simple_trainer_mip_splatting.py @@ -0,0 +1,1094 @@ +import json +import math +import os +import time +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +import imageio +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import viser +import nerfview +from datasets.colmap import Dataset, Parser +from datasets.traj import generate_interpolated_path +from torch import Tensor +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from utils import ( + AppearanceOptModule, + CameraOptModule, + knn, + normalized_quat_to_rotmat, + rgb_to_sh, + set_random_seed, +) + +from gsplat.rendering import rasterization +from gsplat import compute_3D_smoothing_filter + + +@dataclass +class Config: + # Disable viewer + disable_viewer: bool = False + # Path to the .pt file. If provide, it will skip training and render a video + ckpt: Optional[str] = None + + # Path to the Mip-NeRF 360 dataset + data_dir: str = "data/360_v2/garden" + # Downsample factor for the dataset + data_factor: int = 4 + # Directory to save results + result_dir: str = "results/garden" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 30_000 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + + # Initialization strategy + init_type: str = "sfm" + # Initial number of GSs. Ignored if using sfm + init_num_pts: int = 100_000 + # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm + init_extent: float = 3.0 + # Degree of spherical harmonics + sh_degree: int = 3 + # Turn on another SH degree every this steps + sh_degree_interval: int = 1000 + # Initial opacity of GS + init_opa: float = 0.1 + # Initial scale of GS + init_scale: float = 1.0 + # Weight for SSIM loss + ssim_lambda: float = 0.2 + + # Near plane clipping distance + near_plane: float = 0.01 + # Far plane clipping distance + far_plane: float = 1e10 + + # GSs with opacity below this value will be pruned + prune_opa: float = 0.005 + # GSs with image plane gradient above this value will be split/duplicated + grow_grad2d: float = 0.0002 + # GSs with scale below this value will be duplicated. Above will be split + grow_scale3d: float = 0.01 + # GSs with scale above this value will be pruned. 
+ prune_scale3d: float = 0.1 + + # Start refining GSs after this iteration + refine_start_iter: int = 500 + # Stop refining GSs after this iteration + refine_stop_iter: int = 15_000 + # Reset opacities every this steps + reset_every: int = 3000 + # Refine GSs every this steps + refine_every: int = 100 + + # Use packed mode for rasterization, this leads to less memory usage but slightly slower. + packed: bool = False + # Use sparse gradients for optimization. (experimental) + sparse_grad: bool = False + # Use absolute gradient for pruning. This typically requires larger --grow_grad2d, e.g., 0.0008 or 0.0006 + absgrad: bool = False + # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. + antialiased: bool = False + # kernel size for the low-pass filter in rasterization. 0.1 should be a better value since it better approximates a 2D box filter of single pixel size. + kernel_size: float = 0.3 + + # Use random background for training to discourage transparency + random_bkgd: bool = False + + # Enable camera optimization. + pose_opt: bool = False + # Learning rate for camera optimization + pose_opt_lr: float = 1e-5 + # Regularization for camera optimization as weight decay + pose_opt_reg: float = 1e-6 + # Add noise to camera extrinsics. This is only to test the camera pose optimization. + pose_noise: float = 0.0 + + # Enable appearance optimization. (experimental) + app_opt: bool = False + # Appearance embedding dimension + app_embed_dim: int = 16 + # Learning rate for appearance optimization + app_opt_lr: float = 1e-3 + # Regularization for appearance optimization as weight decay + app_opt_reg: float = 1e-6 + + # Enable depth loss. (experimental) + depth_loss: bool = False + # Weight for depth loss + depth_lambda: float = 1e-2 + + # Dump information to tensorboard every this steps + tb_every: int = 100 + # Save training images to tensorboard + tb_save_image: bool = False + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.max_steps = int(self.max_steps * factor) + self.sh_degree_interval = int(self.sh_degree_interval * factor) + self.refine_start_iter = int(self.refine_start_iter * factor) + self.refine_stop_iter = int(self.refine_stop_iter * factor) + self.reset_every = int(self.reset_every * factor) + self.refine_every = int(self.refine_every * factor) + + +def create_splats_with_optimizers( + parser: Parser, + init_type: str = "sfm", + init_num_pts: int = 100_000, + init_extent: float = 3.0, + init_opacity: float = 0.1, + init_scale: float = 1.0, + scene_scale: float = 1.0, + sh_degree: int = 3, + sparse_grad: bool = False, + batch_size: int = 1, + feature_dim: Optional[int] = None, + device: str = "cuda", +) -> Tuple[torch.nn.ParameterDict, torch.optim.Optimizer]: + if init_type == "sfm": + points = torch.from_numpy(parser.points).float() + rgbs = torch.from_numpy(parser.points_rgb / 255.0).float() + elif init_type == "random": + points = init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1) + rgbs = torch.rand((init_num_pts, 3)) + else: + raise ValueError("Please specify a correct init_type: sfm or random") + + N = points.shape[0] + # Initialize the GS size to be the average dist of the 3 nearest neighbors + dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] + quats = torch.rand((N, 4)) # [N, 4] + opacities = 
torch.logit(torch.full((N,), init_opacity)) # [N,] + + params = [ + # name, value, lr + ("means3d", torch.nn.Parameter(points), 1.6e-4 * scene_scale), + ("scales", torch.nn.Parameter(scales), 5e-3), + ("quats", torch.nn.Parameter(quats), 1e-3), + ("opacities", torch.nn.Parameter(opacities), 5e-2), + # 3D smoothing filter, setting lr to 0.0 to disable optimization + ("filters", torch.nn.Parameter(torch.ones_like(opacities)), 0.0), + ] + + if feature_dim is None: + # color is SH coefficients. + colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] + colors[:, 0, :] = rgb_to_sh(rgbs) + params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3)) + params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20)) + else: + # features will be used for appearance and view-dependent shading + features = torch.rand(N, feature_dim) # [N, feature_dim] + params.append(("features", torch.nn.Parameter(features), 2.5e-3)) + colors = torch.logit(rgbs) # [N, 3] + params.append(("colors", torch.nn.Parameter(colors), 2.5e-3)) + + splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) + # Scale learning rate based on batch size, reference: + # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + # Note that this would not make the training exactly equivalent, see + # https://arxiv.org/pdf/2402.18824v1 + optimizers = [ + (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + [{"params": splats[name], "lr": lr * math.sqrt(batch_size), "name": name}], + eps=1e-15 / math.sqrt(batch_size), + betas=(1 - batch_size * (1 - 0.9), 1 - batch_size * (1 - 0.999)), + ) + for name, _, lr in params + ] + return splats, optimizers + + +class Runner: + """Engine for training and testing.""" + + def __init__(self, cfg: Config) -> None: + set_random_seed(42) + + self.cfg = cfg + self.device = "cuda" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}/ckpts" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # Load data: Training data should contain initial points and colors. + self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=True, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + print("Scene scale:", self.scene_scale) + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers = create_splats_with_optimizers( + self.parser, + init_type=cfg.init_type, + init_num_pts=cfg.init_num_pts, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + scene_scale=self.scene_scale, + sh_degree=cfg.sh_degree, + sparse_grad=cfg.sparse_grad, + batch_size=cfg.batch_size, + feature_dim=feature_dim, + device=self.device, + ) + print("Model initialized. 
Number of GS:", len(self.splats["means3d"])) + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + + self.app_optimizers = [] + if cfg.app_opt: + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. + torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + + # Losses & Metrics. + self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to( + self.device + ) + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) + + # Running stats for prunning & growing. + n_gauss = len(self.splats["means3d"]) + self.running_stats = { + "grad2d": torch.zeros(n_gauss, device=self.device), # norm of the gradient + "count": torch.zeros(n_gauss, device=self.device, dtype=torch.int), + } + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict]: + means = self.splats["means3d"] # [N, 3] + # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] + # rasterization does normalization internally + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + filters = self.splats["filters"] # [N,] + + # apply 3D smoothing filter to scales and opacities + scales_square = torch.square(scales) # [N, 3] + det1 = scales_square.prod(dim=1) # [N, ] + + scales_after_square = scales_square + torch.square(filters)[:, None] # [N, 1] + det2 = scales_after_square.prod(dim=1) # [N,] + coef = torch.sqrt(det1 / det2 + 1e-7) # [N,] + opacities = opacities * coef + + scales = torch.square(scales) + torch.square(filters)[:, None] # [N, 3] + scales = torch.sqrt(scales) + + image_ids = kwargs.pop("image_ids", None) + if self.cfg.app_opt: + colors = self.app_module( + features=self.splats["features"], + embed_ids=image_ids, + dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], + sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + ) + colors = colors + self.splats["colors"] + colors = torch.sigmoid(colors) + else: + colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] + + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" + render_colors, render_alphas, info = rasterization( + means=means, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + 
viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=self.cfg.absgrad, + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + eps2d=self.cfg.kernel_size, + **kwargs, + ) + return render_colors, render_alphas, info + + def train(self): + cfg = self.cfg + device = self.device + + # Dump cfg. + with open(f"{cfg.result_dir}/cfg.json", "w") as f: + json.dump(vars(cfg), f) + + max_steps = cfg.max_steps + init_step = 0 + + schedulers = [ + # means3d has a learning rate schedule, that end at 0.01 of the initial value + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ), + ] + if cfg.pose_opt: + # pose optimization has a learning rate schedule + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ) + ) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=4, + persistent_workers=True, + pin_memory=True, + ) + trainloader_iter = iter(trainloader) + + # determine the 3D smoothing filter before training + self.compute_3D_smoothing_filter() + + # Training loop. + global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + for step in pbar: + if not cfg.disable_viewer: + while self.viewer.state.status == "paused": + time.sleep(0.01) + self.viewer.lock.acquire() + tic = time.time() + + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + num_train_rays_per_step = ( + pixels.shape[0] * pixels.shape[1] * pixels.shape[2] + ) + image_ids = data["image_id"].to(device) + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + if cfg.pose_noise: + camtoworlds = self.pose_perturb(camtoworlds, image_ids) + + if cfg.pose_opt: + camtoworlds = self.pose_adjust(camtoworlds, image_ids) + + # sh schedule + sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree) + + # forward + renders, alphas, info = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=sh_degree_to_use, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else "RGB", + ) + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + colors = colors + bkgd * (1.0 - alphas) + + info["means2d"].retain_grad() # used for running stats + + # loss + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - self.ssim( + pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2) + ) + loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda + if cfg.depth_loss: + # query depths from depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate 
loss in disparity space + disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale + loss += depthloss * cfg.depth_lambda + + loss.backward() + + desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + if cfg.pose_opt and cfg.pose_noise: + # monitor the pose error if we inject noise + pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) + desc += f"pose err={pose_err.item():.6f}| " + pbar.set_description(desc) + + if cfg.tb_every > 0 and step % cfg.tb_every == 0: + mem = torch.cuda.max_memory_allocated() / 1024**3 + self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar( + "train/num_GS", len(self.splats["means3d"]), step + ) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.tb_save_image: + canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + # update running stats for prunning & growing + if step < cfg.refine_stop_iter: + self.update_running_stats(info) + + if step > cfg.refine_start_iter and step % cfg.refine_every == 0: + grads = self.running_stats["grad2d"] / self.running_stats[ + "count" + ].clamp_min(1) + + # grow GSs + is_grad_high = grads >= cfg.grow_grad2d + is_small = ( + torch.exp(self.splats["scales"]).max(dim=-1).values + <= cfg.grow_scale3d * self.scene_scale + ) + is_dupli = is_grad_high & is_small + n_dupli = is_dupli.sum().item() + self.refine_duplicate(is_dupli) + + is_split = is_grad_high & ~is_small + is_split = torch.cat( + [ + is_split, + # new GSs added by duplication will not be split + torch.zeros(n_dupli, device=device, dtype=torch.bool), + ] + ) + n_split = is_split.sum().item() + self.refine_split(is_split) + print( + f"Step {step}: {n_dupli} GSs duplicated, {n_split} GSs split. " + f"Now having {len(self.splats['means3d'])} GSs." + ) + + # prune GSs + is_prune = torch.sigmoid(self.splats["opacities"]) < cfg.prune_opa + if step > cfg.reset_every: + # The official code also implements sreen-size pruning but + # it's actually not being used due to a bug: + # https://github.com/graphdeco-inria/gaussian-splatting/issues/123 + is_too_big = ( + torch.exp(self.splats["scales"]).max(dim=-1).values + > cfg.prune_scale3d * self.scene_scale + ) + is_prune = is_prune | is_too_big + n_prune = is_prune.sum().item() + self.refine_keep(~is_prune) + print( + f"Step {step}: {n_prune} GSs pruned. " + f"Now having {len(self.splats['means3d'])} GSs." + ) + self.compute_3D_smoothing_filter() + + # reset running stats + self.running_stats["grad2d"].zero_() + self.running_stats["count"].zero_() + + if step % cfg.reset_every == 0: + self.reset_opa(cfg.prune_opa * 2.0) + + # Turn Gradients into Sparse Tensor before running optimizer + if cfg.sparse_grad: + assert cfg.packed, "Sparse gradients only work with packed mode." + gaussian_ids = info["gaussian_ids"] + for k in self.splats.keys(): + grad = self.splats[k].grad + if grad is None or grad.is_sparse: + continue + self.splats[k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] 
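The refinement step above grows Gaussians whose accumulated screen-space gradient is high, cloning the small ones and splitting the large ones, and prunes Gaussians that are nearly transparent or oversized; the 3D smoothing filter is then recomputed for the new set. A hedged summary of those decisions as boolean masks, with threshold names taken from the Config fields and scales assumed to be already exponentiated:

import torch

def refine_masks(grads, scales, opacities, scene_scale,
                 grow_grad2d=0.0002, grow_scale3d=0.01,
                 prune_opa=0.005, prune_scale3d=0.1):
    """grads: [N] mean 2D gradient norms, scales: [N, 3] (exp-ed), opacities: [N] (sigmoided)."""
    is_grad_high = grads >= grow_grad2d
    is_small = scales.max(dim=-1).values <= grow_scale3d * scene_scale
    is_dupli = is_grad_high & is_small       # clone in place
    is_split = is_grad_high & ~is_small      # replace with two smaller Gaussians
    is_prune = opacities < prune_opa
    # scale-based pruning only activates after the first opacity reset in the trainer
    is_prune = is_prune | (scales.max(dim=-1).values > prune_scale3d * scene_scale)
    return is_dupli, is_split, is_prune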
+ size=self.splats[k].size(), # [N, ...] + is_coalesced=len(Ks) == 1, + ) + + # optimize + for optimizer in self.optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.pose_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.app_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # save checkpoint + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["means3d"]), + } + print("Step: ", step, stats) + with open(f"{self.stats_dir}/train_step{step:04d}.json", "w") as f: + json.dump(stats, f) + torch.save( + { + "step": step, + "splats": self.splats.state_dict(), + }, + f"{self.ckpt_dir}/ckpt_{step}.pt", + ) + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1: + self.eval(step) + self.render_traj(step) + + if not cfg.disable_viewer: + self.viewer.lock.release() + num_train_steps_per_sec = 1.0 / (time.time() - tic) + num_train_rays_per_sec = ( + num_train_rays_per_step * num_train_steps_per_sec + ) + # Update the viewer state. + self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec + # Update the scene. + self.viewer.update(step, num_train_rays_per_step) + + @torch.no_grad() + def compute_3D_smoothing_filter(self): + cfg = self.cfg + device = self.device + xyz = self.splats["means3d"] + print("xyz", xyz.shape, xyz.device) + + worldtocams = ( + torch.from_numpy(self.trainset.parser.worldtocams).float().to(device) + ) + # TODO, currently use K, H, W of the first image + data = self.trainset[0] + K = data["K"].to(device) # [3, 3] + height, width = data["image"].shape[:2] + Ks = torch.stack([K] * len(worldtocams), dim=0) # [C, 3, 3] + filter_3D = compute_3D_smoothing_filter( + xyz, worldtocams, Ks, width, height, cfg.near_plane + ) + filter_3D = filter_3D * (0.2**0.5) + self.splats["filters"] = torch.nn.Parameter(filter_3D) + + @torch.no_grad() + def compute_3D_smoothing_filter_torch(self): + print("Computing 3D filter") + cfg = self.cfg + device = self.device + xyz = self.splats["means3d"] + print("xyz", xyz.shape, xyz.device) + + distance = torch.ones((xyz.shape[0]), device=xyz.device) * 100000.0 + valid_points = torch.zeros((xyz.shape[0]), device=xyz.device, dtype=torch.bool) + focal_length = 0.0 + + for data in self.trainset: + worldtocam = data["worldtocam"].to(device) # [4, 4] + K = data["K"].to(device) # [3, 3] + height, width = data["image"].shape[:2] + R = worldtocam[:3, :3] + T = worldtocam[:3, 3] + + xyz_cam = xyz @ R.transpose(1, 0) + T[None, :] + + # project to screen space + valid_depth = xyz_cam[:, 2] > cfg.near_plane + + x, y, z = xyz_cam[:, 0], xyz_cam[:, 1], xyz_cam[:, 2] + z = torch.clamp(z, min=0.001) + + x = x / z * K[0, 0] + K[0, 2] + y = y / z * K[1, 1] + K[1, 2] + + # use similar tangent space filtering as in 3DGS, + # TODO check gsplat's implementation + in_screen = torch.logical_and( + torch.logical_and(x >= -0.15 * width, x <= width * 1.15), + torch.logical_and(y >= -0.15 * height, y <= 1.15 * height), + ) + valid = torch.logical_and(valid_depth, in_screen) + + distance[valid] = torch.min(distance[valid], z[valid]) + valid_points = torch.logical_or(valid_points, valid) + if focal_length < K[0, 0]: + focal_length = K[0, 0] + + distance[~valid_points] = distance[valid_points].max() + + filter_3D = 
distance / focal_length * (0.2**0.5) + self.splats["filters"] = torch.nn.Parameter(filter_3D) + + @torch.no_grad() + def update_running_stats(self, info: Dict): + """Update running stats.""" + cfg = self.cfg + + # normalize grads to [-1, 1] screen space + if cfg.absgrad: + grads = info["means2d"].absgrad.clone() + else: + grads = info["means2d"].grad.clone() + grads[..., 0] *= info["width"] / 2.0 * cfg.batch_size + grads[..., 1] *= info["height"] / 2.0 * cfg.batch_size + if cfg.packed: + # grads is [nnz, 2] + gs_ids = info["gaussian_ids"] # [nnz] or None + self.running_stats["grad2d"].index_add_(0, gs_ids, grads.norm(dim=-1)) + self.running_stats["count"].index_add_( + 0, gs_ids, torch.ones_like(gs_ids).int() + ) + else: + # grads is [C, N, 2] + sel = info["radii"] > 0.0 # [C, N] + gs_ids = torch.where(sel)[1] # [nnz] + self.running_stats["grad2d"].index_add_(0, gs_ids, grads[sel].norm(dim=-1)) + self.running_stats["count"].index_add_( + 0, gs_ids, torch.ones_like(gs_ids).int() + ) + + @torch.no_grad() + def reset_opa(self, value: float = 0.01): + """Utility function to reset opacities.""" + # opacities = torch.clamp( + # self.splats["opacities"], max=torch.logit(torch.tensor(value)).item() + # ) + + scales = torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + filters = self.splats["filters"] # [N,] + + # apply 3D smoothing filter to scales and opacities + print("apply 3D smoothing filter in reset opacities") + scales_square = torch.square(scales) # [N, 3] + det1 = scales_square.prod(dim=1) # [N, ] + + scales_after_square = scales_square + torch.square(filters)[:, None] # [N, 1] + det2 = scales_after_square.prod(dim=1) # [N,] + coef = torch.sqrt(det1 / det2 + 1e-7) # [N,] + opacities = opacities * coef + + opacities_reset = torch.min(opacities, torch.ones_like(opacities) * value) + opacities_reset = opacities_reset / (coef + 1e-7) + opacities = torch.logit(opacities_reset) + + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + if param_group["name"] != "opacities": + continue + p = param_group["params"][0] + p_state = optimizer.state[p] + del optimizer.state[p] + for key in p_state.keys(): + if key != "step": + p_state[key] = torch.zeros_like(p_state[key]) + p_new = torch.nn.Parameter(opacities) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[param_group["name"]] = p_new + torch.cuda.empty_cache() + + @torch.no_grad() + def refine_split(self, mask: Tensor): + """Utility function to grow GSs.""" + device = self.device + + sel = torch.where(mask)[0] + rest = torch.where(~mask)[0] + + scales = torch.exp(self.splats["scales"][sel]) # [N, 3] + quats = F.normalize(self.splats["quats"][sel], dim=-1) # [N, 4] + rotmats = normalized_quat_to_rotmat(quats) # [N, 3, 3] + samples = torch.einsum( + "nij,nj,bnj->bni", + rotmats, + scales, + torch.randn(2, len(scales), 3, device=device), + ) # [2, N, 3] + + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + p = param_group["params"][0] + name = param_group["name"] + # create new params + if name == "means3d": + p_split = (p[sel] + samples).reshape(-1, 3) # [2N, 3] + elif name == "scales": + p_split = torch.log(scales / 1.6).repeat(2, 1) # [2N, 3] + else: + repeats = [2] + [1] * (p.dim() - 1) + p_split = p[sel].repeat(repeats) + p_new = torch.cat([p[rest], p_split]) + p_new = torch.nn.Parameter(p_new) + # update optimizer + p_state = optimizer.state[p] + del 
optimizer.state[p] + for key in p_state.keys(): + if key == "step": + continue + v = p_state[key] + # new params are assigned with zero optimizer states + # (worth investigating it) + v_split = torch.zeros((2 * len(sel), *v.shape[1:]), device=device) + p_state[key] = torch.cat([v[rest], v_split]) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[name] = p_new + for k, v in self.running_stats.items(): + if v is None: + continue + repeats = [2] + [1] * (v.dim() - 1) + v_new = v[sel].repeat(repeats) + self.running_stats[k] = torch.cat((v[rest], v_new)) + torch.cuda.empty_cache() + + @torch.no_grad() + def refine_duplicate(self, mask: Tensor): + """Unility function to duplicate GSs.""" + sel = torch.where(mask)[0] + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + p = param_group["params"][0] + name = param_group["name"] + p_state = optimizer.state[p] + del optimizer.state[p] + for key in p_state.keys(): + if key != "step": + # new params are assigned with zero optimizer states + # (worth investigating it as it will lead to a lot more GS.) + v = p_state[key] + v_new = torch.zeros( + (len(sel), *v.shape[1:]), device=self.device + ) + # v_new = v[sel] + p_state[key] = torch.cat([v, v_new]) + p_new = torch.nn.Parameter(torch.cat([p, p[sel]])) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[name] = p_new + for k, v in self.running_stats.items(): + self.running_stats[k] = torch.cat((v, v[sel])) + torch.cuda.empty_cache() + + @torch.no_grad() + def refine_keep(self, mask: Tensor): + """Unility function to prune GSs.""" + sel = torch.where(mask)[0] + for optimizer in self.optimizers: + for i, param_group in enumerate(optimizer.param_groups): + p = param_group["params"][0] + name = param_group["name"] + p_state = optimizer.state[p] + del optimizer.state[p] + for key in p_state.keys(): + if key != "step": + p_state[key] = p_state[key][sel] + p_new = torch.nn.Parameter(p[sel]) + optimizer.param_groups[i]["params"] = [p_new] + optimizer.state[p_new] = p_state + self.splats[name] = p_new + for k, v in self.running_stats.items(): + self.running_stats[k] = v[sel] + torch.cuda.empty_cache() + + @torch.no_grad() + def eval(self, step: int): + """Entry for evaluation.""" + print("Running evaluation...") + cfg = self.cfg + device = self.device + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = {"psnr": [], "ssim": [], "lpips": []} + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + colors, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + ) # [1, H, W, 3] + colors = torch.clamp(colors, 0.0, 1.0) + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + # write images + canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() + imageio.imwrite( + f"{self.render_dir}/val_{i:04d}.png", (canvas * 255).astype(np.uint8) + ) + + pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors, pixels)) + metrics["ssim"].append(self.ssim(colors, pixels)) + 
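refine_keep above also shows the optimizer bookkeeping that pruning requires: the Adam moment buffers must be sliced with the same mask as the parameter, and the state dictionary re-keyed to the new Parameter object. A minimal sketch of that pattern, assuming param_groups carry a "name" entry as they do in this trainer:

import torch

def keep_rows(optimizer, splats, name, sel):
    """Keep rows `sel` of parameter `name` and slice its Adam state to match."""
    for i, group in enumerate(optimizer.param_groups):
        if group["name"] != name:
            continue
        p = group["params"][0]
        state = optimizer.state.pop(p, {})
        for key, value in state.items():
            if key != "step":
                state[key] = value[sel]        # exp_avg / exp_avg_sq
        p_new = torch.nn.Parameter(p[sel])
        optimizer.param_groups[i]["params"] = [p_new]
        optimizer.state[p_new] = state
        splats[name] = p_new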
metrics["lpips"].append(self.lpips(colors, pixels)) + + ellipse_time /= len(valloader) + + psnr = torch.stack(metrics["psnr"]).mean() + ssim = torch.stack(metrics["ssim"]).mean() + lpips = torch.stack(metrics["lpips"]).mean() + print( + f"PSNR: {psnr.item():.3f}, SSIM: {ssim.item():.4f}, LPIPS: {lpips.item():.3f} " + f"Time: {ellipse_time:.3f}s/image " + f"Number of GS: {len(self.splats['means3d'])}" + ) + # save stats as json + stats = { + "psnr": psnr.item(), + "ssim": ssim.item(), + "lpips": lpips.item(), + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["means3d"]), + } + with open(f"{self.stats_dir}/val_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"val/{k}", v, step) + self.writer.flush() + + @torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds = self.parser.camtoworlds[5:-5] + camtoworlds = generate_interpolated_path(camtoworlds, 1) # [N, 3, 4] + camtoworlds = np.concatenate( + [ + camtoworlds, + np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds = torch.from_numpy(camtoworlds).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + canvas_all = [] + for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"): + renders, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds[i : i + 1], + Ks=K[None], + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3] + depths = renders[0, ..., 3:4] # [H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + + # write images + canvas = torch.cat( + [colors, depths.repeat(1, 1, 3)], dim=0 if width > height else 1 + ) + canvas = (canvas.cpu().numpy() * 255).astype(np.uint8) + canvas_all.append(canvas) + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + for canvas in canvas_all: + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + render_colors, _, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=W, + height=H, + sh_degree=self.cfg.sh_degree, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + return render_colors[0].cpu().numpy() + + +def main(cfg: Config): + runner = Runner(cfg) + + if cfg.ckpt is not None: + # run eval only + ckpt = torch.load(cfg.ckpt, map_location=runner.device) + for k in runner.splats.keys(): + runner.splats[k].data = ckpt["splats"][k] + runner.eval(step=ckpt["step"]) + # runner.render_traj(step=ckpt["step"]) + else: + runner.train() + + if not cfg.disable_viewer: + print("Viewer running... 
Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + cfg = tyro.cli(Config) + cfg.adjust_steps(cfg.steps_scaler) + main(cfg) diff --git a/examples/utils.py b/examples/utils.py index 89fdaa503..ab6955bf4 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -5,7 +5,7 @@ from sklearn.neighbors import NearestNeighbors from torch import Tensor import torch.nn.functional as F - +import trimesh class CameraOptModule(torch.nn.Module): """Camera pose optimization module.""" @@ -172,3 +172,80 @@ def set_random_seed(seed: int): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) + + +def depths_to_points(K: Tensor, depthmap: Tensor) -> Tensor: + """ + K: camera intrinsic + depth: depthmap + """ + H, W = depthmap.shape + grid_x, grid_y = torch.meshgrid( + torch.arange(W, device="cuda").float() + 0.5, + torch.arange(H, device="cuda").float() + 0.5, + indexing="xy", + ) + points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape( + -1, 3 + ) + rays_d = points @ K.inverse().T + points = depthmap.reshape(-1, 1) * rays_d + return points + + +def depth_to_normal(K: Tensor, depth: Tensor) -> Tensor: + """ + K: camera intrinsic + depth: depthmap + """ + points = depths_to_points(K, depth).reshape(*depth.shape, 3) + output = torch.zeros_like(points) + dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) + dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) + normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) + output[1:-1, 1:-1, :] = normal_map + return output + + +def create_tetrahedra_points(quats, xyz, scale, opacity=None, opacity_threshold=0.1): + device = xyz.device + + M = trimesh.creation.box(extents=[2.0, 2.0, 2.0]) + + quats = F.normalize(quats, p=2, dim=-1) + w, x, y, z = torch.unbind(quats, dim=-1) + R = torch.stack( + [ + 1 - 2 * (y**2 + z**2), + 2 * (x * y - w * z), + 2 * (x * z + w * y), + 2 * (x * y + w * z), + 1 - 2 * (x**2 + z**2), + 2 * (y * z - w * x), + 2 * (x * z - w * y), + 2 * (y * z + w * x), + 1 - 2 * (x**2 + y**2), + ], + dim=-1, + ) + + rots = R.reshape(quats.shape[:-1] + (3, 3)) # (..., 3, 3) + scale = scale * 3.0 # scale up the box at 3 sigma + + # filter points with small opacity + if opacity is not None: + mask = (opacity > opacity_threshold).squeeze(-1) + xyz = xyz[mask] + scale = scale[mask] + rots = rots[mask] + + vertices = M.vertices.T + vertices = torch.from_numpy(vertices).float().to(device).unsqueeze(0).repeat(xyz.shape[0], 1, 1) + # scale vertices first + vertices = vertices * scale.unsqueeze(-1) + vertices = torch.bmm(rots, vertices).squeeze(-1) + xyz.unsqueeze(-1) + vertices = vertices.permute(0, 2, 1).reshape(-1, 3).contiguous() + # concat center points + vertices = torch.cat([vertices, xyz], dim=0) + + return vertices \ No newline at end of file diff --git a/gsplat/__init__.py b/gsplat/__init__.py index cf238ec3b..f2c0a79e9 100644 --- a/gsplat/__init__.py +++ b/gsplat/__init__.py @@ -9,11 +9,18 @@ quat_scale_to_covar_preci, rasterize_to_indices_in_range, rasterize_to_pixels, + raytracing_to_pixels, + integrate_to_points, spherical_harmonics, + view_to_gaussians, world_to_cam, + compute_3D_smoothing_filter, + project_points, ) from .rendering import ( rasterization, + integration, + raytracing, rasterization_inria_wrapper, rasterization_legacy_wrapper, ) @@ -104,10 +111,14 @@ def get_tile_bin_edges(*args, **kwargs): "spherical_harmonics", "isect_offset_encode", "isect_tiles", + "points_isect_tiles", "persp_proj", "fully_fused_projection", 
"quat_scale_to_covar_preci", "rasterize_to_pixels", + "raytracing_to_pixels", + "integrate_to_points", + "view_to_gaussian", "world_to_cam", "accumulate", "rasterize_to_indices_in_range", @@ -120,4 +131,6 @@ def get_tile_bin_edges(*args, **kwargs): "compute_cumulative_intersects", "compute_cov2d_bounds", "get_tile_bin_edges", + "compute_3D_smoothing_filter", + "project_points", ] diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 9043181ce..07247e21c 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -348,6 +348,80 @@ def isect_tiles( return tiles_per_gauss, isect_ids, flatten_ids +@torch.no_grad() +def points_isect_tiles( + means2d: Tensor, # [C, N, 2] or [nnz, 2] + radii: Tensor, # [C, N] or [nnz] + depths: Tensor, # [C, N] or [nnz] + tile_size: int, + tile_width: int, + tile_height: int, + sort: bool = True, + packed: bool = False, + n_cameras: Optional[int] = None, + camera_ids: Optional[Tensor] = None, + gaussian_ids: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Tensor]: + """Maps projected Gaussians to intersecting tiles. + + Args: + means2d: Projected Gaussian means. [C, N, 2] if packed is False, [nnz, 2] if packed is True. + radii: Maximum radii of the projected Gaussians. [C, N] if packed is False, [nnz] if packed is True. + depths: Z-depth of the projected Gaussians. [C, N] if packed is False, [nnz] if packed is True. + tile_size: Tile size. + tile_width: Tile width. + tile_height: Tile height. + sort: If True, the returned intersections will be sorted by the intersection ids. Default: True. + packed: If True, the input tensors are packed. Default: False. + n_cameras: Number of cameras. Required if packed is True. + camera_ids: The row indices of the projected Gaussians. Required if packed is True. + gaussian_ids: The column indices of the projected Gaussians. Required if packed is True. + + Returns: + A tuple: + + - **Tiles per Gaussian**. The number of tiles intersected by each Gaussian. + Int32 [C, N] if packed is False, Int32 [nnz] if packed is True. + - **Intersection ids**. Each id is an 64-bit integer with the following + information: camera_id (Xc bits) | tile_id (Xt bits) | depth (32 bits). + Xc and Xt are the maximum number of bits required to represent the camera and + tile ids, respectively. Int64 [n_isects] + - **Flatten ids**. The global flatten indices in [C * N] or [nnz] (packed). 
[n_isects] + """ + if packed: + nnz = means2d.size(0) + assert means2d.shape == (nnz, 2), means2d.size() + assert radii.shape == (nnz,), radii.size() + assert depths.shape == (nnz,), depths.size() + assert camera_ids is not None, "camera_ids is required if packed is True" + assert gaussian_ids is not None, "gaussian_ids is required if packed is True" + assert n_cameras is not None, "n_cameras is required if packed is True" + camera_ids = camera_ids.contiguous() + gaussian_ids = gaussian_ids.contiguous() + C = n_cameras + + else: + C, N, _ = means2d.shape + assert means2d.shape == (C, N, 2), means2d.size() + assert radii.shape == (C, N), radii.size() + assert depths.shape == (C, N), depths.size() + + isect_ids, flatten_ids = _make_lazy_cuda_func("points_isect_tiles")( + means2d.contiguous(), + radii.contiguous(), + depths.contiguous(), + camera_ids, + gaussian_ids, + C, + tile_size, + tile_width, + tile_height, + sort, + True, # DoubleBuffer: memory efficient radixsort + ) + return isect_ids, flatten_ids + + @torch.no_grad() def isect_offset_encode( isect_ids: Tensor, n_cameras: int, tile_width: int, tile_height: int @@ -531,20 +605,414 @@ def rasterize_to_indices_in_range( image_height: Image height. tile_size: Tile size. isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width] - flatten_ids: The global flatten indices in [C * N] from `isect_tiles()`. [n_isects] + flatten_ids: The global flatten indices in [C * N] from `isect_tiles()`. [n_isects] + + Returns: + A tuple: + + - **Gaussian ids**. Gaussian ids for the pixel intersection. A flattened list of shape [M]. + - **Pixel ids**. pixel indices (row-major). A flattened list of shape [M]. + - **Camera ids**. Camera indices. A flattened list of shape [M]. 
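The intersection ids produced by points_isect_tiles above pack the camera index, the tile index and the depth into one 64-bit key, so a single radix sort orders intersections by camera, then tile, then front-to-back depth (positive float depths keep their order when compared as raw bits). The sketch below assumes a fixed tile-bit width purely for illustration; the CUDA code sizes Xc and Xt from the actual camera and tile counts.

import struct

def encode_isect_id(camera_id, tile_id, depth, tile_bits=22):
    """camera_id (high bits) | tile_id (tile_bits) | raw float32 depth bits (low 32)."""
    depth_u32 = struct.unpack("<I", struct.pack("<f", depth))[0]
    return (camera_id << (32 + tile_bits)) | (tile_id << 32) | depth_u32

def decode_isect_id(isect_id, tile_bits=22):
    depth = struct.unpack("<f", struct.pack("<I", isect_id & 0xFFFFFFFF))[0]
    tile_id = (isect_id >> 32) & ((1 << tile_bits) - 1)
    camera_id = isect_id >> (32 + tile_bits)
    return camera_id, tile_id, depth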
+ """ + + C, N, _ = means2d.shape + assert conics.shape == (C, N, 3), conics.shape + assert opacities.shape == (C, N), opacities.shape + assert isect_offsets.shape[0] == C, isect_offsets.shape + + tile_height, tile_width = isect_offsets.shape[1:3] + assert ( + tile_height * tile_size >= image_height + ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}" + assert ( + tile_width * tile_size >= image_width + ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}" + + out_gauss_ids, out_indices = _make_lazy_cuda_func("rasterize_to_indices_in_range")( + range_start, + range_end, + transmittances.contiguous(), + means2d.contiguous(), + conics.contiguous(), + opacities.contiguous(), + image_width, + image_height, + tile_size, + isect_offsets.contiguous(), + flatten_ids.contiguous(), + ) + out_pixel_ids = out_indices % (image_width * image_height) + out_camera_ids = out_indices // (image_width * image_height) + return out_gauss_ids, out_pixel_ids, out_camera_ids + + +def compute_3D_smoothing_filter( + means: Tensor, # [N, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, +) -> Tensor: + """Compute 3D smoothing filter.""" + C = viewmats.size(0) + N = means.size(0) + assert means.size() == (N, 3), means.size() + assert viewmats.size() == (C, 4, 4), viewmats.size() + assert Ks.size() == (C, 3, 3), Ks.size() + means = means.contiguous() + viewmats = viewmats.contiguous() + Ks = Ks.contiguous() + + return _Compute3DSmoothingFilter.apply( + means, + viewmats, + Ks, + width, + height, + near_plane, + ) + + +def view_to_gaussians( + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] or None + scales: Tensor, # [N, 3] or None + viewmats: Tensor, # [C, 4, 4] + radii: Tensor, # [C, N] + packed: bool = False, + sparse_grad: bool = False, +) -> Tensor: + """Projects Gaussians to 2D. + + This function fuse the process of computing covariances + (:func:`quat_scale_to_covar_preci()`), transforming to camera space (:func:`world_to_cam()`), + and perspective projection (:func:`persp_proj()`). + + .. note:: + + During projection, we ignore the Gaussians that are outside of the camera frustum. + So not all the elements in the output tensors are valid. The output `radii` could serve as + an indicator, in which zero radii means the corresponding elements are invalid in + the output tensors and will be ignored in the next rasterization process. If `packed=True`, + the output tensors will be packed into a flattened tensor, in which all elements are valid. + In this case, a `camera_ids` tensor and `gaussian_ids` tensor will be returned to indicate the + row (camera) and column (Gaussian) indices of the packed flattened tensor, which is essentially + following the COO sparse tensor format. + + .. note:: + + This functions supports projecting Gaussians with either covariances or {quaternions, scales}, + which will be converted to covariances internally in a fused CUDA kernel. Either `covars` or + {`quats`, `scales`} should be provided. + + Args: + means: Gaussian means. [N, 3] + covars: Gaussian covariances (flattened upper triangle). [N, 6] Optional. + quats: Quaternions (No need to be normalized). [N, 4] Optional. + scales: Scales. [N, 3] Optional. + viewmats: Camera-to-world matrices. [C, 4, 4] + Ks: Camera intrinsics. [C, 3, 3] + width: Image width. + height: Image height. + eps2d: A epsilon added to the 2D covariance for numerical stability. Default: 0.3. + near_plane: Near plane distance. Default: 0.01. 
+ far_plane: Far plane distance. Default: 1e10. + radius_clip: Gaussians with projected radii smaller than this value will be ignored. Default: 0.0. + packed: If True, the output tensors will be packed into a flattened tensor. Default: False. + sparse_grad: This is only effective when `packed` is True. If True, during backward the gradients + of {`means`, `covars`, `quats`, `scales`} will be a sparse Tensor in COO layout. Default: False. + calc_compensations: If True, a view-dependent opacity compensation factor will be computed, which + is useful for anti-aliasing. Default: False. + + Returns: + A tuple: + + If `packed` is True: + + - **camera_ids**. The row indices of the projected Gaussians. Int32 tensor of shape [nnz]. + - **gaussian_ids**. The column indices of the projected Gaussians. Int32 tensor of shape [nnz]. + - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [nnz]. + - **means**. Projected Gaussian means in 2D. [nnz, 2] + - **depths**. The z-depth of the projected Gaussians. [nnz] + - **conics**. Inverse of the projected covariances. Return the flattend upper triangle with [nnz, 3] + - **compensations**. The view-dependent opacity compensation factor. [nnz] + + If `packed` is False: + + - **radii**. The maximum radius of the projected Gaussians in pixel unit. Int32 tensor of shape [C, N]. + - **means**. Projected Gaussian means in 2D. [C, N, 2] + - **depths**. The z-depth of the projected Gaussians. [C, N] + - **conics**. Inverse of the projected covariances. Return the flattend upper triangle with [C, N, 3] + - **compensations**. The view-dependent opacity compensation factor. [C, N] + """ + C = viewmats.size(0) + N = means.size(0) + assert means.size() == (N, 3), means.size() + assert viewmats.size() == (C, 4, 4), viewmats.size() + assert quats is not None, "covars or quats is required" + assert scales is not None, "covars or scales is required" + assert quats.size() == (N, 4), quats.size() + assert scales.size() == (N, 3), scales.size() + assert radii.size() == (C, N), radii.size() + means = means.contiguous() + radii = radii.contiguous() + quats = quats.contiguous() + scales = scales.contiguous() + + if sparse_grad: + assert packed, "sparse_grad is only supported when packed is True" + + viewmats = viewmats.contiguous() + + if packed: + raise NotImplementedError("packed mode is not supported") + else: + return _ViewToGaussians.apply( + means, + quats, + scales, + viewmats, + radii, + ) + + +def raytracing_to_pixels( + means2d: Tensor, # [C, N, 2] or [nnz, 2] + conics: Tensor, # [C, N, 3] or [nnz, 3] + colors: Tensor, # [C, N, channels] or [nnz, channels] + opacities: Tensor, # [C, N] or [nnz] + view2guassian: Tensor, # [C, N, 10] or [nnz, 10] + Ks: Tensor, # [C, 3, 3] + image_width: int, + image_height: int, + tile_size: int, + isect_offsets: Tensor, # [C, tile_height, tile_width] + flatten_ids: Tensor, # [n_isects] + backgrounds: Optional[Tensor] = None, # [C, channels] + packed: bool = False, + absgrad: bool = False, +) -> Tuple[Tensor, Tensor]: + """Rasterizes Gaussians to pixels. + + Args: + means2d: Projected Gaussian means. [C, N, 2] if packed is False, [nnz, 2] if packed is True. + conics: Inverse of the projected covariances with only upper triangle values. [C, N, 3] if packed is False, [nnz, 3] if packed is True. + colors: Gaussian colors or ND features. [C, N, channels] if packed is False, [nnz, channels] if packed is True. + opacities: Gaussian opacities that support per-view values. 
[C, N] if packed is False, [nnz] if packed is True. + image_width: Image width. + image_height: Image height. + tile_size: Tile size. + isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width] + flatten_ids: The global flatten indices in [C * N] or [nnz] from `isect_tiles()`. [n_isects] + backgrounds: Background colors. [C, channels]. Default: None. + packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False. + absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False. + + Returns: + A tuple: + + - **Rendered colors**. [C, image_height, image_width, channels] + - **Rendered alphas**. [C, image_height, image_width, 1] + """ + assert not packed, "raytracing_to_pixels only supports non-packed mode" + + C = isect_offsets.size(0) + device = means2d.device + if packed: + nnz = means2d.size(0) + assert means2d.shape == (nnz, 2), means2d.shape + assert conics.shape == (nnz, 3), conics.shape + assert colors.shape[0] == nnz, colors.shape + assert opacities.shape == (nnz,), opacities.shape + assert view2guassian.shape == (nnz, 10), view2guassian.shape + else: + N = means2d.size(1) + assert means2d.shape == (C, N, 2), means2d.shape + assert conics.shape == (C, N, 3), conics.shape + assert colors.shape[:2] == (C, N), colors.shape + assert opacities.shape == (C, N), opacities.shape + assert view2guassian.shape == (C, N, 10), view2guassian.shape + assert Ks.shape == (C, 3, 3), Ks.shape + + if backgrounds is not None: + assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape + backgrounds = backgrounds.contiguous() + + # Pad the channels to the nearest supported number if necessary + channels = colors.shape[-1] + if channels > 513 or channels == 0: + # TODO: maybe worth to support zero channels? 
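The branch just below pads the color dimension up to the next power of two whenever it is not one of the natively supported channel counts, and strips the padding from the rendered output afterwards. A small restatement of that rule:

def padded_channel_count(channels):
    """Extra zero channels appended so the kernel sees a templated channel count."""
    supported = {1, 2, 3, 4, 5, 8, 9, 16, 17, 32, 33, 64, 65,
                 128, 129, 256, 257, 512, 513}
    if channels == 0 or channels > 513:
        raise ValueError(f"Unsupported number of color channels: {channels}")
    if channels in supported:
        return 0
    return (1 << (channels - 1).bit_length()) - channels

# e.g. padded_channel_count(6) == 2, so a 6-channel feature map is rendered with 8 channels.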
+ raise ValueError(f"Unsupported number of color channels: {channels}") + if channels not in ( + 1, + 2, + 3, + 4, + 5, + 8, + 9, + 16, + 17, + 32, + 33, + 64, + 65, + 128, + 129, + 256, + 257, + 512, + 513, + ): + padded_channels = (1 << (channels - 1).bit_length()) - channels + colors = torch.cat( + [ + colors, + torch.zeros(*colors.shape[:-1], padded_channels, device=device), + ], + dim=-1, + ) + if backgrounds is not None: + backgrounds = torch.cat( + [ + backgrounds, + torch.zeros( + *backgrounds.shape[:-1], padded_channels, device=device + ), + ], + dim=-1, + ) + else: + padded_channels = 0 + + tile_height, tile_width = isect_offsets.shape[1:3] + assert ( + tile_height * tile_size >= image_height + ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}" + assert ( + tile_width * tile_size >= image_width + ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}" + + render_colors, render_alphas = _RayTracingToPixels.apply( + means2d.contiguous(), + conics.contiguous(), + colors.contiguous(), + opacities.contiguous(), + view2guassian.contiguous(), + Ks.contiguous(), + backgrounds, + image_width, + image_height, + tile_size, + isect_offsets.contiguous(), + flatten_ids.contiguous(), + absgrad, + ) + + if padded_channels > 0: + render_colors = render_colors[..., :-padded_channels] + return render_colors, render_alphas + +@torch.no_grad() +def project_points( + means: Tensor, # [N, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, + far_plane: float = 1e10, +) -> Tensor: + """Project 3D points to 2D images space""" + C = viewmats.size(0) + N = means.size(0) + assert means.size() == (N, 3), means.size() + assert viewmats.size() == (C, 4, 4), viewmats.size() + assert Ks.size() == (C, 3, 3), Ks.size() + means = means.contiguous() + viewmats = viewmats.contiguous() + Ks = Ks.contiguous() + + radii, means2d, depths = _make_lazy_cuda_func("project_points_fwd")( + means, + viewmats, + Ks, + width, + height, + near_plane, + far_plane + ) + return radii, means2d, depths + + +def integrate_to_points( + points2d: Tensor, # [C, N, 2] or [nnz, 2] + point_depths: Tensor, # [C, N] or [nnz] + means2d: Tensor, # [C, N, 2] or [nnz, 2] + conics: Tensor, # [C, N, 3] or [nnz, 3] + colors: Tensor, # [C, N, channels] or [nnz, channels] + opacities: Tensor, # [C, N] or [nnz] + view2guassian: Tensor, # [C, N, 10] or [nnz, 10] + Ks: Tensor, # [C, 3, 3] + image_width: int, + image_height: int, + tile_size: int, + isect_offsets: Tensor, # [C, tile_height, tile_width] + flatten_ids: Tensor, # [n_isects] + point_isect_offsets: Tensor, # [C, tile_height, tile_width] + point_flatten_ids: Tensor, # [n_isects] + backgrounds: Optional[Tensor] = None, # [C, channels] + packed: bool = False, + absgrad: bool = False, +) -> Tuple[Tensor, Tensor]: + """Rasterizes Gaussians to pixels. + + Args: + means2d: Projected Gaussian means. [C, N, 2] if packed is False, [nnz, 2] if packed is True. + conics: Inverse of the projected covariances with only upper triangle values. [C, N, 3] if packed is False, [nnz, 3] if packed is True. + colors: Gaussian colors or ND features. [C, N, channels] if packed is False, [nnz, channels] if packed is True. + opacities: Gaussian opacities that support per-view values. [C, N] if packed is False, [nnz] if packed is True. + image_width: Image width. + image_height: Image height. + tile_size: Tile size. + isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. 
[C, tile_height, tile_width] + flatten_ids: The global flatten indices in [C * N] or [nnz] from `isect_tiles()`. [n_isects] + backgrounds: Background colors. [C, channels]. Default: None. + packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False. + absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False. Returns: A tuple: - - **Gaussian ids**. Gaussian ids for the pixel intersection. A flattened list of shape [M]. - - **Pixel ids**. pixel indices (row-major). A flattened list of shape [M]. - - **Camera ids**. Camera indices. A flattened list of shape [M]. + - **Rendered colors**. [C, image_height, image_width, channels] + - **Rendered alphas**. [C, image_height, image_width, 1] """ + assert not packed, "integrate_to_points only supports non-packed mode." - C, N, _ = means2d.shape - assert conics.shape == (C, N, 3), conics.shape - assert opacities.shape == (C, N), opacities.shape - assert isect_offsets.shape[0] == C, isect_offsets.shape + C = isect_offsets.size(0) + device = means2d.device + if packed: + nnz = means2d.size(0) + assert means2d.shape == (nnz, 2), means2d.shape + assert conics.shape == (nnz, 3), conics.shape + assert colors.shape[0] == nnz, colors.shape + assert opacities.shape == (nnz,), opacities.shape + assert view2guassian.shape == (nnz, 10), view2guassian.shape + else: + N = means2d.size(1) + assert means2d.shape == (C, N, 2), means2d.shape + assert conics.shape == (C, N, 3), conics.shape + assert colors.shape[:2] == (C, N), colors.shape + assert opacities.shape == (C, N), opacities.shape + assert view2guassian.shape == (C, N, 10), view2guassian.shape + assert Ks.shape == (C, 3, 3), Ks.shape + PN = points2d.size(1) + assert points2d.shape == (C, PN, 2), points2d.shape + assert point_depths.shape == (C, PN), point_depths.shape + if backgrounds is not None: + assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape + backgrounds = backgrounds.contiguous() tile_height, tile_width = isect_offsets.shape[1:3] assert ( @@ -554,23 +1022,29 @@ def rasterize_to_indices_in_range( tile_width * tile_size >= image_width ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}" - out_gauss_ids, out_indices = _make_lazy_cuda_func("rasterize_to_indices_in_range")( - range_start, - range_end, - transmittances.contiguous(), - means2d.contiguous(), - conics.contiguous(), - opacities.contiguous(), - image_width, - image_height, - tile_size, - isect_offsets.contiguous(), - flatten_ids.contiguous(), - ) - out_pixel_ids = out_indices % (image_width * image_height) - out_camera_ids = out_indices // (image_width * image_height) - return out_gauss_ids, out_pixel_ids, out_camera_ids + render_colors, render_alphas = _make_lazy_cuda_func( + "integrate_to_points_fwd" + )( + points2d, + point_depths, + means2d, + conics, + colors, + opacities, + view2guassian, + Ks, + backgrounds, + image_width, + image_height, + tile_size, + isect_offsets, + flatten_ids, + point_isect_offsets, + point_flatten_ids, + ) + + return render_colors, render_alphas class _QuatScaleToCovarPreci(torch.autograd.Function): """Converts quaternions and scales to covariance and precision matrices.""" @@ -927,6 +1401,209 @@ def backward( ) +class _RayTracingToPixels(torch.autograd.Function): + """Ray tracing gaussians""" + + @staticmethod + def forward( + ctx, + means2d: Tensor, # [C, N, 2] + conics: Tensor, # [C, N, 3] + colors: Tensor, # [C, N, D] + opacities: Tensor, # [C, N] + view2guassian: Tensor, # [C, N, 10] + Ks: 
Tensor, # [C, 3, 3] + backgrounds: Tensor, # [C, D], Optional + width: int, + height: int, + tile_size: int, + isect_offsets: Tensor, # [C, tile_height, tile_width] + flatten_ids: Tensor, # [n_isects] + absgrad: bool, + ) -> Tuple[Tensor, Tensor]: + render_colors, render_alphas, last_ids = _make_lazy_cuda_func( + "raytracing_to_pixels_fwd" + )( + means2d, + conics, + colors, + opacities, + view2guassian, + Ks, + backgrounds, + width, + height, + tile_size, + isect_offsets, + flatten_ids, + ) + + ctx.save_for_backward( + means2d, + conics, + colors, + opacities, + view2guassian, + Ks, + backgrounds, + isect_offsets, + flatten_ids, + render_alphas, + last_ids, + ) + ctx.width = width + ctx.height = height + ctx.tile_size = tile_size + ctx.absgrad = absgrad + + # double to float + render_alphas = render_alphas.float() + return render_colors, render_alphas + + @staticmethod + def backward( + ctx, + v_render_colors: Tensor, # [C, H, W, 3] + v_render_alphas: Tensor, # [C, H, W, 1] + ): + ( + means2d, + conics, + colors, + opacities, + view2guassian, + Ks, + backgrounds, + isect_offsets, + flatten_ids, + render_alphas, + last_ids, + ) = ctx.saved_tensors + width = ctx.width + height = ctx.height + tile_size = ctx.tile_size + absgrad = ctx.absgrad + + ( + v_means2d_abs, + v_means2d, + v_conics, + v_colors, + v_opacities, + v_view2guassian, + ) = _make_lazy_cuda_func("raytracing_to_pixels_bwd")( + means2d, + conics, + colors, + opacities, + view2guassian, + Ks, + backgrounds, + width, + height, + tile_size, + isect_offsets, + flatten_ids, + render_alphas, + last_ids, + v_render_colors.contiguous(), + v_render_alphas.contiguous(), + absgrad, + ) + + if absgrad: + means2d.absgrad = v_means2d_abs + + if ctx.needs_input_grad[6]: + v_backgrounds = (v_render_colors * (1.0 - render_alphas).float()).sum( + dim=(1, 2) + ) + else: + v_backgrounds = None + + return ( + v_means2d, + v_conics, + v_colors, + v_opacities, + v_view2guassian, + None, + v_backgrounds, + None, + None, + None, + None, + None, + None, + ) + + +class _ViewToGaussians(torch.autograd.Function): + """Compute View to Gaussians.""" + + @staticmethod + def forward( + ctx, + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] + scales: Tensor, # [N, 3] + viewmats: Tensor, # [C, 4, 4] + radii: Tensor, # [C, N] + ) -> Tensor: + + view2gaussians = _make_lazy_cuda_func("view_to_gaussians_fwd")( + means, + quats, + scales, + viewmats, + radii, + ) + + ctx.save_for_backward(means, quats, scales, viewmats, radii, view2gaussians) + + return view2gaussians + + @staticmethod + def backward(ctx, v_view2gaussians): + ( + means, + quats, + scales, + viewmats, + radii, + view2gaussians, + ) = ctx.saved_tensors + + v_means, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func( + "view_to_gaussians_bwd" + )( + means, + quats, + scales, + viewmats, + radii, + view2gaussians.contiguous(), + v_view2gaussians.contiguous(), + ctx.needs_input_grad[3], # viewmats_requires_grad + ) + if not ctx.needs_input_grad[0]: + v_means = None + if not ctx.needs_input_grad[1]: + v_quats = None + if not ctx.needs_input_grad[2]: + v_scales = None + if not ctx.needs_input_grad[3]: + v_viewmats = None + + return ( + v_means, + v_quats, + v_scales, + v_viewmats, + None, + ) + + class _FullyFusedProjectionPacked(torch.autograd.Function): """Projects Gaussians to 2D. 
Return packed tensors.""" @@ -1143,3 +1820,28 @@ def backward(ctx, v_colors: Tensor): if not compute_v_dirs: v_dirs = None return None, v_dirs, v_coeffs, None + + +class _Compute3DSmoothingFilter(torch.autograd.Function): + """Compute 3D Smoothing filter.""" + + @staticmethod + def forward( + ctx, + means: Tensor, # [N, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float, + ) -> Tensor: + filter_3D = _make_lazy_cuda_func("compute_3D_smoothing_filter_fwd")( + means, + viewmats, + Ks, + width, + height, + near_plane, + ) + + return filter_3D diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index c983f461e..cf9c625ad 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -108,6 +108,16 @@ isect_tiles_tensor(const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const uint32_t tile_width, const uint32_t tile_height, const bool sort, const bool double_buffer); +std::tuple +points_isect_tiles_tensor(const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &radii, // [C, N] or [nnz] + const torch::Tensor &depths, // [C, N] or [nnz] + const at::optional &camera_ids, // [nnz] + const at::optional &gaussian_ids, // [nnz] + const uint32_t C, const uint32_t tile_size, + const uint32_t tile_width, const uint32_t tile_height, + const bool sort, const bool double_buffer); + torch::Tensor isect_offset_encode_tensor(const torch::Tensor &isect_ids, // [n_isects] const uint32_t C, const uint32_t tile_width, const uint32_t tile_height); @@ -126,6 +136,29 @@ std::tuple rasterize_to_pixels_fwd_ const torch::Tensor &flatten_ids // [n_isects] ); +torch::Tensor view_to_gaussians_fwd_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &radii // [C, N] +); + +std::tuple +view_to_gaussians_bwd_tensor( + // fwd inputs + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &radii, // [C, N] + // fwd outputs + const torch::Tensor &view2gaussians, // [C, N, 10] + // grad outputs + const torch::Tensor &v_view2gaussians, // [C, N, 10] + const bool viewmats_requires_grad +); + std::tuple rasterize_to_pixels_bwd_tensor( // Gaussian parameters @@ -148,6 +181,46 @@ rasterize_to_pixels_bwd_tensor( // options bool absgrad); +std::tuple raytracing_to_pixels_fwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] + const torch::Tensor &conics, // [C, N, 3] + const torch::Tensor &colors, // [C, N, D] + const torch::Tensor &opacities, // [N] + const torch::Tensor &view2gaussians, // [C, N, 10] + const torch::Tensor &Ks, // [C, 3, 3] + const at::optional &backgrounds, // [C, D] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +); + +std::tuple +raytracing_to_pixels_bwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] + const torch::Tensor &conics, // [C, N, 3] + const torch::Tensor &colors, // [C, N, 3] + const torch::Tensor &opacities, // [N] + const torch::Tensor &view2gaussians, // [C, N, 10] + const torch::Tensor &Ks, // [C, 3, 3] + const at::optional &backgrounds, // [C, 3] + // image size + 
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // forward outputs + const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &last_ids, // [C, image_height, image_width] + // gradients of outputs + const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + // options + bool absgrad); + std::tuple rasterize_to_indices_in_range_tensor( const uint32_t range_start, const uint32_t range_end, // iteration steps const torch::Tensor transmittances, // [C, image_height, image_width] @@ -175,6 +248,44 @@ compute_sh_bwd_tensor(const uint32_t K, const uint32_t degrees_to_use, torch::Tensor &v_colors, // [..., 3] bool compute_v_dirs); +torch::Tensor compute_3D_smoothing_filter_fwd_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, const uint32_t image_height, + const float near_plane); + +std::tuple +project_points_fwd_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, const uint32_t image_height, + const float near_plane, const float far_plane); + +std::tuple +integrate_to_points_fwd_tensor( + // Point parameters + const torch::Tensor &points2d, // [C, N, 2] + const torch::Tensor &point_depths, // [C, N, 3] + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] + const torch::Tensor &conics, // [C, N, 3] + const torch::Tensor &colors, // [C, N, D] + const torch::Tensor &opacities, // [N] + const torch::Tensor &view2gaussians, // [C, N, 10] + const torch::Tensor &Ks, // [C, 3, 3] + const at::optional &backgrounds, // [C, D] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // points intersections + const torch::Tensor &point_tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &point_flatten_ids // [n_isects] +); + /**************************************************************************************** * Packed Version ****************************************************************************************/ diff --git a/gsplat/cuda/csrc/compute_3D_smoothing_filter_fwd.cu b/gsplat/cuda/csrc/compute_3D_smoothing_filter_fwd.cu new file mode 100644 index 000000000..a5f337aa0 --- /dev/null +++ b/gsplat/cuda/csrc/compute_3D_smoothing_filter_fwd.cu @@ -0,0 +1,110 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "utils.cuh" + +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + + +/**************************************************************************** + * Compute the 3D smoothing filter size of 3D Gaussians Forward Pass + ****************************************************************************/ + +template +__global__ void +compute_3D_smoothing_filter_fwd_kernel(const uint32_t C, const uint32_t N, + const T *__restrict__ means, // [N, 3] + const T *__restrict__ viewmats, // [C, 4, 4] + const T *__restrict__ Ks, // [C, 3, 3] + const int32_t image_width, const int32_t image_height, + const T near_plane, + // 
outputs + T *__restrict__ filter // [N, ] +) { + // parallelize over C * N. + uint32_t idx = cg::this_grid().thread_rank(); + if (idx >= C * N) { + return; + } + const uint32_t cid = idx / N; // camera id + const uint32_t gid = idx % N; // gaussian id + + // shift pointers to the current camera and gaussian + means += gid * 3; + viewmats += cid * 16; + Ks += cid * 9; + + // glm is column-major but input is row-major + mat3 R = mat3(viewmats[0], viewmats[4], viewmats[8], // 1st column + viewmats[1], viewmats[5], viewmats[9], // 2nd column + viewmats[2], viewmats[6], viewmats[10] // 3rd column + ); + vec3 t = vec3(viewmats[3], viewmats[7], viewmats[11]); + + // transform Gaussian center to camera space + vec3 mean_c; + pos_world_to_cam(R, t, glm::make_vec3(means), mean_c); + if (mean_c.z < near_plane ) { + return; + } + + // project the point to image plane + vec2 mean2d; + + const T fx = Ks[0]; + const T fy = Ks[4]; + const T cx = Ks[2]; + const T cy = Ks[5]; + + T x = mean_c[0], y = mean_c[1], z = mean_c[2]; + T rz = 1.f / z; + mean2d = vec2({fx * x * rz + cx, fy * y * rz + cy}); + + // mask out gaussians outside the image region + if (mean2d.x <= 0 || mean2d.x >= image_width || + mean2d.y <= 0 || mean2d.y >= image_height) { + return; + } + + T filter_size = z / fx; + + // write to outputs + // atomicMin(&filter[gid], filter_size); + + // atomicMin is not supported for float, so we use __float_as_int + // refer to https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda/51549250#51549250 + atomicMin((int *)&filter[gid], __float_as_int(filter_size)); +} + + +torch::Tensor compute_3D_smoothing_filter_fwd_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, const uint32_t image_height, + const float near_plane) { + DEVICE_GUARD(means); + CHECK_INPUT(means); + CHECK_INPUT(viewmats); + CHECK_INPUT(Ks); + + uint32_t N = means.size(0); // number of gaussians + uint32_t C = viewmats.size(0); // number of cameras + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + torch::Tensor filter = torch::full({N}, 1000000, means.options()); + + if (C && N) { + compute_3D_smoothing_filter_fwd_kernel<<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>( + C, N, means.data_ptr(), + viewmats.data_ptr(), Ks.data_ptr(), image_width, image_height, + near_plane, + filter.data_ptr()); + } + return filter; +} diff --git a/gsplat/cuda/csrc/ext.cpp b/gsplat/cuda/csrc/ext.cpp index 6ef34e7e1..d545afed1 100644 --- a/gsplat/cuda/csrc/ext.cpp +++ b/gsplat/cuda/csrc/ext.cpp @@ -17,14 +17,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fully_fused_projection_fwd", &fully_fused_projection_fwd_tensor); m.def("fully_fused_projection_bwd", &fully_fused_projection_bwd_tensor); + m.def("view_to_gaussians_fwd", &view_to_gaussians_fwd_tensor); + m.def("view_to_gaussians_bwd", &view_to_gaussians_bwd_tensor); + m.def("isect_tiles", &isect_tiles_tensor); m.def("isect_offset_encode", &isect_offset_encode_tensor); + m.def("points_isect_tiles", &points_isect_tiles_tensor); + m.def("rasterize_to_pixels_fwd", &rasterize_to_pixels_fwd_tensor); m.def("rasterize_to_pixels_bwd", &rasterize_to_pixels_bwd_tensor); + m.def("raytracing_to_pixels_fwd", &raytracing_to_pixels_fwd_tensor); + m.def("raytracing_to_pixels_bwd", &raytracing_to_pixels_bwd_tensor); + m.def("rasterize_to_indices_in_range", &rasterize_to_indices_in_range_tensor); + 
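A note on the compute_3D_smoothing_filter_fwd kernel added above: CUDA has no native atomicMin for float, so the kernel reinterprets the candidate filter size's bits as a signed int and calls atomicMin on that, as the linked Stack Overflow answer describes. This is only safe because every value involved is non-negative — the candidates are z / fx with z >= near_plane, and the buffer is pre-filled with 1000000 — and for non-negative IEEE-754 floats the int32 bit pattern increases monotonically with the float value. The snippet below is a hedged Python illustration of that ordering property only; the helper name is ours, not gsplat API, and it is not part of the patch.

    import struct

    def float_bits_as_int(x: float) -> int:
        # int32 view of a float32 bit pattern, analogous to __float_as_int on the GPU
        return struct.unpack("<i", struct.pack("<f", x))[0]

    # For non-negative floats, comparing bit patterns as signed ints
    # gives the same order as comparing the floats themselves.
    assert float_bits_as_int(0.25) < float_bits_as_int(1000000.0)
    assert sorted([3.0, 0.5, 7.25], key=float_bits_as_int) == [0.5, 3.0, 7.25]
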
m.def("compute_3D_smoothing_filter_fwd", &compute_3D_smoothing_filter_fwd_tensor); + m.def("project_points_fwd", &project_points_fwd_tensor); + m.def("integrate_to_points_fwd", &integrate_to_points_fwd_tensor); + // packed version m.def("fully_fused_projection_packed_fwd", &fully_fused_projection_packed_fwd_tensor); m.def("fully_fused_projection_packed_bwd", &fully_fused_projection_packed_bwd_tensor); diff --git a/gsplat/cuda/csrc/integrate_to_points_fwd.cu b/gsplat/cuda/csrc/integrate_to_points_fwd.cu new file mode 100644 index 000000000..d28420cd4 --- /dev/null +++ b/gsplat/cuda/csrc/integrate_to_points_fwd.cu @@ -0,0 +1,621 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "types.cuh" +#include +#include +#include + +namespace cg = cooperative_groups; + +/**************************************************************************** + * integrate to Points Forward Pass + * Current implementation is the same as GOF but it could be potentially better parallelized if we distribute the points uniformly instead of using the thread the it projected into + ****************************************************************************/ + +template +__global__ void integrate_to_points_fwd_kernel( + const uint32_t C, const uint32_t N, const uint32_t n_isects, + const uint32_t PN, const uint32_t point_n_isects, + const bool packed, + const vec2 *__restrict__ points2d, // [C, N, 2] or [nnz, 2] + const S *__restrict__ point_depths, // [C, N, 3] or [nnz, 3] + const vec2 *__restrict__ means2d, // [C, N, 2] or [nnz, 2] + const vec3 *__restrict__ conics, // [C, N, 3] or [nnz, 3] + const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] + const S *__restrict__ opacities, // [C, N] or [nnz] + const vec10 *__restrict__ view2gaussians, // [C, N, 10] or [nnz, 10] + const S *__restrict__ Ks, // [C, 3, 3] + const S *__restrict__ backgrounds, // [C, COLOR_DIM] + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const uint32_t tile_width, const uint32_t tile_height, + const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] + const int32_t *__restrict__ flatten_ids, // [n_isects] + const int32_t *__restrict__ point_tile_offsets, // [C, tile_height, tile_width] + const int32_t *__restrict__ point_flatten_ids, // [n_isects] + S *__restrict__ render_colors, // [C, image_height, image_width, COLOR_DIM] + S *__restrict__ render_alphas, // [C, image_height, image_width, 1] + S *__restrict__ out_integrated_alphas, // [C, PN] + int32_t *__restrict__ last_ids // [C, image_height, image_width] +) { + // each thread draws one pixel, but also timeshares caching gaussians in a + // shared tile + + auto block = cg::this_thread_block(); + int32_t camera_id = block.group_index().x; + int32_t tile_id = block.group_index().y * tile_width + block.group_index().z; + uint32_t i = block.group_index().y * tile_size + block.thread_index().y; + uint32_t j = block.group_index().z * tile_size + block.thread_index().x; + + tile_offsets += camera_id * tile_height * tile_width; + render_colors += camera_id * image_height * image_width * (COLOR_DIM + 1 + 3); + render_alphas += camera_id * image_height * image_width; + out_integrated_alphas += camera_id * PN; + last_ids += camera_id * image_height * image_width * 2; + Ks += camera_id * 9; + + if (backgrounds != nullptr) { + backgrounds += camera_id * COLOR_DIM; + } + + S px = (S)j + 0.5f; + S py = (S)i + 0.5f; + int32_t pix_id = i * image_width + j; + + const S focal_x = Ks[0]; + const S focal_y = Ks[4]; + const S cx = 
Ks[2]; + const S cy = Ks[5]; + const vec3 ray = {(px - cx) / focal_x, (py - cy) / focal_y, 1.0}; + + // return if out of bounds + // keep not rasterizing threads around for reading data + bool inside = (i < image_height && j < image_width); + bool done = !inside; + + const uint32_t block_size = block.size(); + // have all threads in tile process the same gaussians in batches + // first collect gaussians between range.x and range.y in batches + // which gaussians to look through in this tile + int32_t range_start = tile_offsets[tile_id]; + int32_t range_end = + (camera_id == C - 1) && (tile_id == tile_width * tile_height - 1) + ? n_isects + : tile_offsets[tile_id + 1]; + uint32_t num_batches = (range_end - range_start + block_size - 1) / block_size; + int32_t point_range_start = point_tile_offsets[tile_id]; + int32_t point_range_end = + (camera_id == C - 1) && (tile_id == tile_width * tile_height - 1) + ? point_n_isects + : point_tile_offsets[tile_id + 1]; + uint32_t point_num_batches = (point_range_end - point_range_start + block_size - 1) / block_size; + + extern __shared__ int s[]; + int32_t *id_batch = (int32_t *)s; // [block_size] + vec3 *xy_opacity_batch = + reinterpret_cast *>(&id_batch[block_size]); // [block_size] + vec3 *conic_batch = + reinterpret_cast *>(&xy_opacity_batch[block_size]); // [block_size] + vec10 *view2gaussian_batch = + reinterpret_cast *>(&conic_batch[block_size]); // [block_size] + vec3 *points2d_depth_batch = + reinterpret_cast *>(&view2gaussian_batch[block_size]); // [block_size] + + // current visibility left to render + // transmittance is gonna be used in the backward pass which requires a high + // numerical precision so we use double for it. However double make bwd 1.5x slower + // so we stick with float for now. + S T = 1.0f; + // index of most recent gaussian to write to this thread's pixel + uint32_t last_contributor = 0; + uint32_t max_contributor = -1; + // collect and process batches of gaussians + // each thread loads one gaussian at a time before rasterizing its + // designated pixel + uint32_t tr = block.thread_rank(); + + // + 1 for depth and + 3 for normal + S pix_out[COLOR_DIM + 1 + 3] = {0.f}; + + #define MAX_NUM_CONTRIBUTORS 256 + + uint32_t n_contrib_local = 0; + uint16_t contributed_ids[MAX_NUM_CONTRIBUTORS*4] = { 0 }; + // use 4 additional corner points so that we have more accurate estimation of contributed_ids + S corner_Ts[5] = { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f }; + S offset_xs[5] = { 0.0f, -0.5f, 0.5f, -0.5f, 0.5f }; + S offset_ys[5] = { 0.0f, -0.5f, -0.5f, 0.5f, 0.5f }; + + for (uint32_t b = 0; b < num_batches; ++b) { + // resync all threads before beginning next batch + // end early if entire tile is done + if (__syncthreads_count(done) >= block_size) { + break; + } + + // each thread fetch 1 gaussian from front to back + // index of gaussian to load + uint32_t batch_start = range_start + block_size * b; + uint32_t idx = batch_start + tr; + if (idx < range_end) { + int32_t g = flatten_ids[idx]; // flatten index in [C * N] or [nnz] + id_batch[tr] = g; + const vec2 xy = means2d[g]; + const S opac = opacities[g]; + xy_opacity_batch[tr] = {xy.x, xy.y, opac}; + conic_batch[tr] = conics[g]; + view2gaussian_batch[tr] = view2gaussians[g]; + } + + // wait for other threads to collect the gaussians in batch + block.sync(); + + // process gaussians in the current batch for this pixel + uint32_t batch_size = min(block_size, range_end - batch_start); + for (uint32_t t = 0; (t < batch_size) && !done; ++t) { + const vec3 conic = conic_batch[t]; + 
const vec3 xy_opac = xy_opacity_batch[t]; + const S opac = xy_opac.z; + // why not use pointer so we don't need to copy again? + const vec10 view2gaussian = view2gaussian_batch[t]; + + bool used = false; + for (uint32_t k = 0; k < 5; ++k){ + const vec3 ray_k = {(px + offset_xs[k] - cx) / focal_x, (py +offset_ys[k] - cy) / focal_y, 1.0}; + + const vec3 normal = { + view2gaussian[0] * ray_k.x + view2gaussian[1] * ray_k.y + view2gaussian[2], + view2gaussian[1] * ray_k.x + view2gaussian[3] * ray_k.y + view2gaussian[4], + view2gaussian[2] * ray_k.x + view2gaussian[4] * ray_k.y + view2gaussian[5] + }; + + // use AA, BB, CC so that the name is unique + S AA = ray_k.x * normal[0] + ray_k.y * normal[1] + normal[2]; + S BB = 2 * (view2gaussian[6] * ray_k.x + view2gaussian[7] * ray_k.y + view2gaussian[8]); + S CC = view2gaussian[9]; + + // t is the depth of the gaussian + S depth = -BB/(2*AA); + + //TODO take near plane as input + #define NEAR_PLANE 0.01f + // depth must be positive otherwise it is not valid and we skip it + if (depth <= NEAR_PLANE) + continue; + + // the scale of the gaussian is 1.f / sqrt(AA) + S min_value = -(BB/AA) * (BB/4.) + CC; + + S power = -0.5f * min_value; + if (power > 0.0f){ + power = 0.0f; + } + + S alpha = min(0.999f, opac * exp(power)); + + if (alpha < 1.f / 255.f) { + continue; + } + + const S next_T = corner_Ts[k] * (1.0f - alpha); + if (next_T <= 1e-4) { // this pixel is done: exclusive + // done = true; + continue; + } + + int32_t g = id_batch[t]; + const S vis = alpha * corner_Ts[k]; + if (k == 0){ + const S *c_ptr = colors + g * COLOR_DIM; + PRAGMA_UNROLL + for (uint32_t ch = 0; ch < COLOR_DIM; ++ch) { + pix_out[ch] += c_ptr[ch] * vis; + } + } + + last_contributor = batch_start + t - range_start; + + // if (pix_id == 550 * 1237 + 600 && k == 0){ + // printf("contributed_ids for middle pixel: %d\n", last_contributor); + // } + + corner_Ts[k] = next_T; + used = true; + } + if (used){ + + contributed_ids[n_contrib_local] = (u_int16_t)last_contributor; + n_contrib_local += 1; + + if (n_contrib_local >= MAX_NUM_CONTRIBUTORS * 4){ + done = true; + printf("ERROR: Maximal contributors are met. This should be fixed! %d\n", n_contrib_local); + break; + } + // if (pix_id == 550 * 1237 + 600){ + // printf("contributed_ids: %d\n", last_contributor); + // } + } + + } + } + + if (inside) { + // Here T is the transmittance AFTER the last gaussian in this pixel. + // We (should) store double precision as T would be used in backward pass and + // it can be very small and causing large diff in gradients with float32. + // However, double precision makes the backward pass 1.5x slower so we stick + // with float for now. + render_alphas[pix_id] = 1.0f - T; + PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + render_colors[pix_id * (COLOR_DIM + 1 + 3) + k] = + backgrounds == nullptr ? pix_out[k] : (pix_out[k] + T * backgrounds[k]); + } + } + + // return; + + // Allocate storage for batches of collectively fetched data. 
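Before the point-integration pass below, a recap of the per-ray math used in the pass above: the 10 view2gaussian values let the kernel write a splat's squared Mahalanobis distance along a pixel ray as a quadratic in the ray depth t, G(t) = AA*t^2 + BB*t + CC, so the response peaks at depth t = -BB/(2*AA) with weight exp(-0.5 * (CC - BB^2/(4*AA))). The Python sketch below just replays that arithmetic with the same variable names; the helper and its inputs are illustrative assumptions, not gsplat API, and it is not part of the patch.

    import math

    def ray_gaussian_peak(AA, BB, CC, opac, near_plane=0.01):
        """Peak response of one splat along a pixel ray, mirroring the kernel above."""
        depth = -BB / (2.0 * AA)                # depth of maximum response
        if depth <= near_plane:
            return depth, 0.0                   # skipped, as in the kernel
        min_value = CC - BB * BB / (4.0 * AA)   # minimum of AA*t^2 + BB*t + CC
        power = min(0.0, -0.5 * min_value)      # clamped to <= 0
        alpha = min(0.999, opac * math.exp(power))
        return depth, alpha

In the point-integration pass that follows, the same quadratic is evaluated again, but the depth is first clamped to each query point's depth before computing the weight, which is how per-point alphas and transmittances are accumulated.
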
+ #define MAX_NUM_PROJECTED 256 + int32_t projected_ids[MAX_NUM_PROJECTED] = { 0 }; + float2 projected_xy[MAX_NUM_PROJECTED] = { 0.f }; + float projected_depth[MAX_NUM_PROJECTED] = { 0.f }; + + //TODO add a for loop here in case we got more points than MAX_NUM_PROJECTED + uint32_t point_counter_last = 0; + bool point_done = !inside; + int total_projected = 0; + + while (true){ + // resync all threads before beginning next batch + // end early if entire tile is done + if (__syncthreads_count(point_done) >= block_size) { + break; + } + + int num_projected = 0; + bool excced_max_projected = false; + done = false; + + uint32_t point_counter = 0; + + // check how many points are projected to this pixel + for (uint32_t b = 0; b < point_num_batches; ++b) { + // resync all threads before beginning next batch + // end early if entire tile is done + if (__syncthreads_count(done) >= block_size) { + break; + } + + // each thread fetch 1 gaussian from front to back + // index of gaussian to load + uint32_t point_batch_start = point_range_start + block_size * b; + uint32_t idx = point_batch_start + tr; + if (idx < point_range_end) { + int32_t g = point_flatten_ids[idx]; // flatten index in [C * N] or [nnz] + id_batch[tr] = g; + const vec2 xy = points2d[g]; + const S depth = point_depths[g]; + points2d_depth_batch[tr] = {xy.x, xy.y, depth}; + } + + // wait for other threads to collect the gaussians in batch + block.sync(); + + // process gaussians in the current batch for this pixel + uint32_t batch_size = min(block_size, point_range_end - point_batch_start); + for (uint32_t t = 0; (t < batch_size) && !done; ++t) { + point_counter++; + if (point_counter <= point_counter_last){ + continue; + } + const vec3 point_xy_depth = points2d_depth_batch[t]; + const float2 point_xy = {point_xy_depth.x, point_xy_depth.y}; + const float depth = point_xy_depth.z; + if ((point_xy.x >= (px - 0.5)) && (point_xy.x < (px + 0.5)) && + (point_xy.y >= (py - 0.5)) && (point_xy.y < (py + 0.5))){ + //TODO maybe add a check for depth to filter some points + if (true ){ + + if (num_projected >= MAX_NUM_PROJECTED){ + done = true; + excced_max_projected = true; + break; + } + + projected_ids[num_projected] = id_batch[t]; + projected_xy[num_projected] = point_xy; + projected_depth[num_projected] = depth; + num_projected += 1; + } + // if (pix_id == 550 * 1237 + 600){ + // printf("projected_points_ids: %d points_xy_depth: %.10f %.10f %.10f\n", id_batch[t], point_xy.x, point_xy.y, depth); + // } + // ((points2d[0, :, 0:1] >= 600) & (points2d[0, :, 0:1] <= 601) & (points2d[0, :, 1:2] >= 401) & (points2d[0, :, 1:2] <=402) ).float().argmax() + } + + } + } + // return; + + point_counter_last = point_counter - 1; + point_done = !excced_max_projected; + total_projected += num_projected; + + // reiterate all primitives + //TODO we could allocate the memory with dynamic size + float point_alphas[MAX_NUM_PROJECTED] = { 0.f}; + float point_Ts[MAX_NUM_PROJECTED] = {0.f}; + for (int i = 0; i < num_projected; i++){ + point_Ts[i] = 1.f; + } + + uint32_t num_iterated = 0; + // bool second_done = !inside; + done = !inside; + uint16_t num_contributed_second = 0; + + for (uint32_t b = 0; b < num_batches; ++b) { + // resync all threads before beginning next batch + // end early if entire tile is done + if (__syncthreads_count(done) >= block_size) { + break; + } + + // each thread fetch 1 gaussian from front to back + // index of gaussian to load + uint32_t batch_start = range_start + block_size * b; + uint32_t idx = batch_start + tr; + if (idx < 
range_end) { + int32_t g = flatten_ids[idx]; // flatten index in [C * N] or [nnz] + id_batch[tr] = g; + const vec2 xy = means2d[g]; + const S opac = opacities[g]; + xy_opacity_batch[tr] = {xy.x, xy.y, opac}; + conic_batch[tr] = conics[g]; + view2gaussian_batch[tr] = view2gaussians[g]; + } + + // wait for other threads to collect the gaussians in batch + block.sync(); + + // process gaussians in the current batch for this pixel + uint32_t batch_size = min(block_size, range_end - batch_start); + for (uint32_t t = 0; (t < batch_size) && !done; ++t) { + uint32_t cur_idx = batch_start + t - range_start; + + if (cur_idx > last_contributor){ + done = true; + continue; + } + + if (cur_idx != (uint32_t)contributed_ids[num_contributed_second]){ + continue; + } else{ + num_contributed_second += 1; + } + + const vec3 conic = conic_batch[t]; + const vec3 xy_opac = xy_opacity_batch[t]; + const S opac = xy_opac.z; + // why not use pointer so we don't need to copy again? + const vec10 view2gaussian = view2gaussian_batch[t]; + + for (uint32_t k = 0; k < num_projected; ++k){ + // if (pix_id == 550 * 1237 + 600 && k == 0){ + // printf("ealpha last_contributor: %d num_iterated: %d\n", batch_start + t - range_start, cur_idx); + // } + + const vec3 ray_k = {(projected_xy[k].x - cx) / focal_x, (projected_xy[k].y - cy) / focal_y, 1.0}; + const S ray_depth = projected_depth[k]; + + const vec3 normal = { + view2gaussian[0] * ray_k.x + view2gaussian[1] * ray_k.y + view2gaussian[2], + view2gaussian[1] * ray_k.x + view2gaussian[3] * ray_k.y + view2gaussian[4], + view2gaussian[2] * ray_k.x + view2gaussian[4] * ray_k.y + view2gaussian[5] + }; + + // use AA, BB, CC so that the name is unique + S AA = ray_k.x * normal[0] + ray_k.y * normal[1] + normal[2]; + S BB = 2 * (view2gaussian[6] * ray_k.x + view2gaussian[7] * ray_k.y + view2gaussian[8]); + S CC = view2gaussian[9]; + + // t is the depth of the gaussian + S depth = -BB/(2*AA); + + if (depth > ray_depth){ + depth = ray_depth; + } + + S power = -0.5f * (AA * depth * depth + BB * depth + CC); + if (power > 0.0f){ + power = 0.0f; + } + + S alpha = min(0.999f, opac * exp(power)); + + if (alpha < 1.f / 255.f) { + continue; + } + + const S next_T = point_Ts[k] * (1.0f - alpha); + const S vis = alpha * point_Ts[k]; + point_alphas[k] += vis; + point_Ts[k] = next_T; + } + } + } + + if (inside){ + for (int k = 0; k < num_projected; k++){ + out_integrated_alphas[projected_ids[k]] = point_alphas[k]; + // // write colors + // for (int ch = 0; ch < CHANNELS; ch++) + // out_color_integrated[CHANNELS * projected_ids[k] + ch] = C[ch] + corner_Ts[0] * bg_color[ch];; + } + } + + } +} + +template +std::tuple call_kernel_with_dim( + // Point parameters + const torch::Tensor &points2d, // [C, N, 2] + const torch::Tensor &point_depths, // [C, N, 3] + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &view2gaussians, // [C, N, 10] or [nnz, 10] + const torch::Tensor &Ks, // [C, 3, 3] + const at::optional &backgrounds, // [C, channels] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // points intersections + const torch::Tensor &point_tile_offsets, // [C, tile_height, 
tile_width] + const torch::Tensor &point_flatten_ids // [n_isects] +) { + DEVICE_GUARD(means2d); + CHECK_INPUT(means2d); + CHECK_INPUT(conics); + CHECK_INPUT(colors); + CHECK_INPUT(opacities); + CHECK_INPUT(view2gaussians); + CHECK_INPUT(Ks); + CHECK_INPUT(tile_offsets); + CHECK_INPUT(flatten_ids); + CHECK_INPUT(points2d); + CHECK_INPUT(point_depths); + CHECK_INPUT(point_tile_offsets); + CHECK_INPUT(point_flatten_ids); + if (backgrounds.has_value()) { + CHECK_INPUT(backgrounds.value()); + } + bool packed = means2d.dim() == 2; + + uint32_t C = tile_offsets.size(0); // number of cameras + uint32_t N = packed ? 0 : means2d.size(1); // number of gaussians + uint32_t channels = colors.size(-1); + uint32_t tile_height = tile_offsets.size(1); + uint32_t tile_width = tile_offsets.size(2); + uint32_t n_isects = flatten_ids.size(0); + + uint32_t PN = points2d.size(1); + uint32_t point_n_isects = point_flatten_ids.size(0); + // Each block covers a tile on the image. In total there are + // C * tile_height * tile_width blocks. + dim3 threads = {tile_size, tile_size, 1}; + dim3 blocks = {C, tile_height, tile_width}; + + // + 1 for depth and + 3 for normal + torch::Tensor renders = torch::zeros({C, image_height, image_width, channels + 1 + 3}, + means2d.options().dtype(torch::kFloat32)); + torch::Tensor alphas = torch::empty({C, image_height, image_width, 1}, + means2d.options().dtype(torch::kFloat32)); + torch::Tensor out_integrated_alphas = torch::full({C, PN}, 1.0, + means2d.options().dtype(torch::kFloat32)); + printf("out_integrated_alphas size: %d %d\n", out_integrated_alphas.size(0), out_integrated_alphas.size(1)); + // 1 for last_ids and 1 for max_contributor + torch::Tensor last_ids = torch::empty({C, image_height, image_width, 2}, + means2d.options().dtype(torch::kInt32)); + + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + const uint32_t shared_mem = + tile_size * tile_size * + (sizeof(int32_t) + sizeof(vec3) + sizeof(vec3) + sizeof(vec10) + sizeof(vec3)); + + // TODO: an optimization can be done by passing the actual number of channels into + // the kernel functions and avoid necessary global memory writes. This requires + // moving the channel padding from python to C side. + if (cudaFuncSetAttribute(integrate_to_points_fwd_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem) != cudaSuccess) { + AT_ERROR("Failed to set maximum shared memory size (requested ", shared_mem, + " bytes), try lowering tile_size."); + } + integrate_to_points_fwd_kernel + <<>>( + C, N, n_isects, PN, point_n_isects, packed, + reinterpret_cast *>(points2d.data_ptr()), + point_depths.data_ptr(), + reinterpret_cast *>(means2d.data_ptr()), + reinterpret_cast *>(conics.data_ptr()), + colors.data_ptr(), opacities.data_ptr(), + reinterpret_cast *>(view2gaussians.data_ptr()), + Ks.data_ptr(), + backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, + image_width, image_height, tile_size, tile_width, tile_height, + tile_offsets.data_ptr(), flatten_ids.data_ptr(), + point_tile_offsets.data_ptr(), point_flatten_ids.data_ptr(), + renders.data_ptr(), alphas.data_ptr(), + out_integrated_alphas.data_ptr(), + last_ids.data_ptr()); + + return std::make_tuple(renders, out_integrated_alphas); +} + +std::tuple integrate_to_points_fwd_tensor( + // Point parameters + const torch::Tensor &points2d, // [C, N, 2] + const torch::Tensor &point_depths, // [C, N, 3] + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &view2gaussians, // [C, N, 10] or [nnz, 10] + const torch::Tensor &Ks, // [C, 3, 3] + const at::optional &backgrounds, // [C, channels] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // points intersections + const torch::Tensor &point_tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &point_flatten_ids // [n_isects] +) { + CHECK_INPUT(colors); + uint32_t channels = colors.size(-1); + +#define __GS__CALL_(N) \ + case N: \ + return call_kernel_with_dim(points2d, point_depths, \ + means2d, conics, colors, opacities, \ + view2gaussians, Ks, \ + backgrounds, image_width, image_height, \ + tile_size, tile_offsets, flatten_ids, \ + point_tile_offsets, point_flatten_ids); + + // TODO: an optimization can be done by passing the actual number of channels into + // the kernel functions and avoid necessary global memory writes. This requires + // moving the channel padding from python to C side. + switch (channels) { + __GS__CALL_(1) + __GS__CALL_(2) + __GS__CALL_(3) + __GS__CALL_(4) + __GS__CALL_(5) + __GS__CALL_(8) + __GS__CALL_(9) + __GS__CALL_(16) + __GS__CALL_(17) + __GS__CALL_(32) + __GS__CALL_(33) + __GS__CALL_(64) + __GS__CALL_(65) + __GS__CALL_(128) + __GS__CALL_(129) + __GS__CALL_(256) + __GS__CALL_(257) + __GS__CALL_(512) + __GS__CALL_(513) + default: + AT_ERROR("Unsupported number of channels: ", channels); + } +} diff --git a/gsplat/cuda/csrc/isect_tiles.cu b/gsplat/cuda/csrc/isect_tiles.cu index f91677ac4..2dc4123ea 100644 --- a/gsplat/cuda/csrc/isect_tiles.cu +++ b/gsplat/cuda/csrc/isect_tiles.cu @@ -58,7 +58,7 @@ __global__ void isect_tiles( // tile_min is inclusive, tile_max is exclusive uint2 tile_min, tile_max; - tile_min.x = min(max(0, (uint32_t)floor(tile_x - tile_radius)), tile_width); + tile_min.x = min(max(0, (uint32_t)floor(tile_x - tile_radius)), tile_width); tile_min.y = min(max(0, (uint32_t)floor(tile_y - tile_radius)), tile_height); tile_max.x = min(max(0, (uint32_t)ceil(tile_x + tile_radius)), tile_width); tile_max.y = min(max(0, (uint32_t)ceil(tile_y + tile_radius)), tile_height); @@ -231,6 +231,187 @@ isect_tiles_tensor(const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] } } +template +__global__ void points_isect_tiles( + // if the data is [C, N, ...] or [nnz, ...] 
(packed) + const bool packed, + // parallelize over C * N, only used if packed is False + const uint32_t C, const uint32_t N, + // parallelize over nnz, only used if packed is True + const uint32_t nnz, + const int64_t *__restrict__ camera_ids, // [nnz] optional + const int64_t *__restrict__ gaussian_ids, // [nnz] optional + // data + const T *__restrict__ means2d, // [C, N, 2] or [nnz, 2] + const int32_t *__restrict__ radii, // [C, N] or [nnz] + const T *__restrict__ depths, // [C, N] or [nnz] + const int64_t *__restrict__ cum_tiles_per_gauss, // [C, N] or [nnz] + const uint32_t tile_size, const uint32_t tile_width, const uint32_t tile_height, + const uint32_t tile_n_bits, + int64_t *__restrict__ isect_ids, // [n_isects] + int32_t *__restrict__ flatten_ids // [n_isects] +) { + // For now we'll upcast float16 and bfloat16 to float32 + using OpT = typename OpType::type; + + // parallelize over C * N. + uint32_t idx = cg::this_grid().thread_rank(); + bool first_pass = cum_tiles_per_gauss == nullptr; + if (idx >= (packed ? nnz : C * N)) { + return; + } + + const OpT radius = radii[idx]; + if (radius <= 0) { + return; + } + + vec2 mean2d = glm::make_vec2(means2d + 2*idx); + + OpT tile_x = mean2d.x / static_cast(tile_size); + OpT tile_y = mean2d.y / static_cast(tile_size); + + uint32_t x = min(max(0, (uint32_t)floor(tile_x)), tile_width); + uint32_t y = min(max(0, (uint32_t)floor(tile_y)), tile_height); + + int64_t cid; // camera id + if (packed) { + // parallelize over nnz + cid = camera_ids[idx]; + // gid = gaussian_ids[idx]; + } else { + // parallelize over C * N + cid = idx / N; + // gid = idx % N; + } + const int64_t cid_enc = cid << (32 + tile_n_bits); + + int64_t depth_id_enc = (int64_t) * (int32_t *)&(depths[idx]); + int64_t cur_idx = (idx == 0) ? 0 : cum_tiles_per_gauss[idx - 1]; + int64_t tile_id = y * tile_width + x; + // e.g. 
tile_n_bits = 22: + // camera id (10 bits) | tile id (22 bits) | depth (32 bits) + isect_ids[cur_idx] = cid_enc | (tile_id << 32) | depth_id_enc; + // the flatten index in [C * N] or [nnz] + flatten_ids[cur_idx] = static_cast(idx); + +} + +std::tuple +points_isect_tiles_tensor(const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &radii, // [C, N] or [nnz] + const torch::Tensor &depths, // [C, N] or [nnz] + const at::optional &camera_ids, // [nnz] + const at::optional &gaussian_ids, // [nnz] + const uint32_t C, const uint32_t tile_size, + const uint32_t tile_width, const uint32_t tile_height, + const bool sort, const bool double_buffer) { + DEVICE_GUARD(means2d); + CHECK_INPUT(means2d); + CHECK_INPUT(radii); + CHECK_INPUT(depths); + if (camera_ids.has_value()) { + CHECK_INPUT(camera_ids.value()); + } + if (gaussian_ids.has_value()) { + CHECK_INPUT(gaussian_ids.value()); + } + bool packed = means2d.dim() == 2; + + uint32_t N = 0, nnz = 0, total_elems = 0; + int64_t *camera_ids_ptr = nullptr; + int64_t *gaussian_ids_ptr = nullptr; + if (packed) { + nnz = means2d.size(0); + total_elems = nnz; + TORCH_CHECK(camera_ids.has_value() && gaussian_ids.has_value(), + "When packed is set, camera_ids and gaussian_ids must be provided."); + camera_ids_ptr = camera_ids.value().data_ptr(); + gaussian_ids_ptr = gaussian_ids.value().data_ptr(); + } else { + N = means2d.size(1); // number of gaussians + total_elems = C * N; + } + + uint32_t n_tiles = tile_width * tile_height; + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + // the number of bits needed to encode the camera id and tile id + // Note: std::bit_width requires C++20 + // uint32_t tile_n_bits = std::bit_width(n_tiles); + // uint32_t cam_n_bits = std::bit_width(C); + uint32_t tile_n_bits = (uint32_t)floor(log2(n_tiles)) + 1; + uint32_t cam_n_bits = (uint32_t)floor(log2(C)) + 1; + // the first 32 bits are used for the camera id and tile id altogether, so + // check if we have enough bits for them. 
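To make the sort order explicit: each 64-bit intersection key packs the camera id above bit (32 + tile_n_bits), the tile id starting at bit 32, and the float32 bit pattern of the depth into the low 32 bits, so radix-sorting the keys groups intersections by camera, then by tile, then by ascending depth (valid for the non-negative depths produced here). The Python sketch below shows the layout for illustration only; the helper name is ours, not gsplat API.

    import struct

    def encode_isect_id(cid: int, tile_id: int, depth: float, tile_n_bits: int) -> int:
        # low 32 bits: float32 bit pattern of the depth (assumes depth >= 0)
        depth_bits = struct.unpack("<I", struct.pack("<f", depth))[0]
        return (cid << (32 + tile_n_bits)) | (tile_id << 32) | depth_bits

    # keys for the same camera/tile sort by depth
    k_near = encode_isect_id(0, 5, 0.7, tile_n_bits=22)
    k_far = encode_isect_id(0, 5, 3.2, tile_n_bits=22)
    assert k_near < k_far
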
+ assert(tile_n_bits + cam_n_bits <= 32); + + torch::Tensor cum_tiles_per_gauss = torch::cumsum(radii.view({-1}), 0); + + int64_t n_isects = cum_tiles_per_gauss[-1].item(); + + torch::Tensor isect_ids = + torch::empty({n_isects}, depths.options().dtype(torch::kInt64)); + torch::Tensor flatten_ids = + torch::empty({n_isects}, depths.options().dtype(torch::kInt32)); + if (n_isects) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, means2d.scalar_type(), "isect_tiles_n_isects", [&]() { + points_isect_tiles<<<(total_elems + N_THREADS - 1) / N_THREADS, N_THREADS, 0, + stream>>>( + packed, C, N, nnz, camera_ids_ptr, gaussian_ids_ptr, + reinterpret_cast(means2d.data_ptr()), + radii.data_ptr(), depths.data_ptr(), + cum_tiles_per_gauss.data_ptr(), tile_size, tile_width, tile_height, + tile_n_bits, isect_ids.data_ptr(), + flatten_ids.data_ptr()); + }); + } + + // optionally sort the Gaussians by isect_ids + if (n_isects && sort) { + torch::Tensor isect_ids_sorted = torch::empty_like(isect_ids); + torch::Tensor flatten_ids_sorted = torch::empty_like(flatten_ids); + + // https://nvidia.github.io/cccl/cub/api/structcub_1_1DeviceRadixSort.html + // DoubleBuffer reduce the auxiliary memory usage from O(N+P) to O(P) + if (double_buffer) { + // Create a set of DoubleBuffers to wrap pairs of device pointers + cub::DoubleBuffer d_keys(isect_ids.data_ptr(), + isect_ids_sorted.data_ptr()); + cub::DoubleBuffer d_values(flatten_ids.data_ptr(), + flatten_ids_sorted.data_ptr()); + CUB_WRAPPER(cub::DeviceRadixSort::SortPairs, d_keys, d_values, n_isects, 0, + 32 + tile_n_bits + cam_n_bits, stream); + switch (d_keys.selector) { + case 0: // sorted items are stored in isect_ids + isect_ids_sorted = isect_ids; + break; + case 1: // sorted items are stored in isect_ids_sorted + break; + } + switch (d_values.selector) { + case 0: // sorted items are stored in flatten_ids + flatten_ids_sorted = flatten_ids; + break; + case 1: // sorted items are stored in flatten_ids_sorted + break; + } + // printf("DoubleBuffer d_keys selector: %d\n", d_keys.selector); + // printf("DoubleBuffer d_values selector: %d\n", + // d_values.selector); + } else { + CUB_WRAPPER(cub::DeviceRadixSort::SortPairs, isect_ids.data_ptr(), + isect_ids_sorted.data_ptr(), + flatten_ids.data_ptr(), + flatten_ids_sorted.data_ptr(), n_isects, 0, + 32 + tile_n_bits + cam_n_bits, stream); + } + return std::make_tuple(isect_ids_sorted, flatten_ids_sorted); + } else { + return std::make_tuple(isect_ids, flatten_ids); + } +} + __global__ void isect_offset_encode(const uint32_t n_isects, const int64_t *__restrict__ isect_ids, const uint32_t C, const uint32_t n_tiles, diff --git a/gsplat/cuda/csrc/project_points_fwd.cu b/gsplat/cuda/csrc/project_points_fwd.cu new file mode 100644 index 000000000..9feb3cf1c --- /dev/null +++ b/gsplat/cuda/csrc/project_points_fwd.cu @@ -0,0 +1,118 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "utils.cuh" + +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + + +/**************************************************************************** + * Projection of Points (Single Batch) Forward Pass + ****************************************************************************/ + +template +__global__ void +project_points_fwd_kernel(const uint32_t C, const uint32_t N, + const T *__restrict__ means, // [N, 3] + const T *__restrict__ viewmats, // [C, 4, 4] + const T *__restrict__ Ks, // [C, 3, 3] + const int32_t image_width, const int32_t image_height, + const T 
near_plane, const T far_plane, + // outputs + int32_t *__restrict__ radii, // [C, N] + T *__restrict__ means2d, // [C, N, 2] + T *__restrict__ depths // [C, N] +) { + // parallelize over C * N. + uint32_t idx = cg::this_grid().thread_rank(); + if (idx >= C * N) { + return; + } + const uint32_t cid = idx / N; // camera id + const uint32_t gid = idx % N; // gaussian id + + // shift pointers to the current camera and gaussian + means += gid * 3; + viewmats += cid * 16; + Ks += cid * 9; + + // glm is column-major but input is row-major + mat3 R = mat3(viewmats[0], viewmats[4], viewmats[8], // 1st column + viewmats[1], viewmats[5], viewmats[9], // 2nd column + viewmats[2], viewmats[6], viewmats[10] // 3rd column + ); + vec3 t = vec3(viewmats[3], viewmats[7], viewmats[11]); + + // transform Gaussian center to camera space + vec3 mean_c; + pos_world_to_cam(R, t, glm::make_vec3(means), mean_c); + if (mean_c.z < near_plane || mean_c.z > far_plane) { + radii[idx] = 0; + return; + } + + // project the point to image plane + vec2 mean2d; + + const T fx = Ks[0]; + const T fy = Ks[4]; + const T cx = Ks[2]; + const T cy = Ks[5]; + + T x = mean_c[0], y = mean_c[1], z = mean_c[2]; + T rz = 1.f / z; + mean2d = vec2({fx * x * rz + cx, fy * y * rz + cy}); + + // mask out gaussians outside the image region + if (mean2d.x <= 0 || mean2d.x >= image_width || + mean2d.y <= 0 || mean2d.y >= image_height) { + radii[idx] = 0; + return; + } + + // write to outputs + radii[idx] = 1; + means2d[idx * 2] = mean2d.x; + means2d[idx * 2 + 1] = mean2d.y; + depths[idx] = mean_c.z; +} + + +std::tuple +project_points_fwd_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, const uint32_t image_height, + const float near_plane, const float far_plane) { + DEVICE_GUARD(means); + CHECK_INPUT(means); + CHECK_INPUT(viewmats); + CHECK_INPUT(Ks); + + uint32_t N = means.size(0); // number of gaussians + uint32_t C = viewmats.size(0); // number of cameras + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + torch::Tensor radii = torch::empty({C, N}, means.options().dtype(torch::kInt32)); + torch::Tensor means2d = torch::empty({C, N, 2}, means.options()); + torch::Tensor depths = torch::empty({C, N}, means.options()); + + if (C && N) { + project_points_fwd_kernel<<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>( + C, N, means.data_ptr(), + viewmats.data_ptr(), + Ks.data_ptr(), + image_width, image_height, + near_plane, far_plane, + radii.data_ptr(), + means2d.data_ptr(), + depths.data_ptr()); + } + return std::make_tuple(radii, means2d, depths); +} diff --git a/gsplat/cuda/csrc/raytracing_to_pixels_bwd.cu b/gsplat/cuda/csrc/raytracing_to_pixels_bwd.cu new file mode 100644 index 000000000..99ae79f0f --- /dev/null +++ b/gsplat/cuda/csrc/raytracing_to_pixels_bwd.cu @@ -0,0 +1,560 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "types.cuh" +#include +#include +#include + +namespace cg = cooperative_groups; + +/**************************************************************************** + * Rasterization to Pixels Backward Pass + ****************************************************************************/ + +template +__global__ void raytracing_to_pixels_bwd_kernel( + const uint32_t C, const uint32_t N, const uint32_t n_isects, const bool packed, + // fwd inputs + const vec2 *__restrict__ means2d, // [C, N, 2] or [nnz, 2] + const vec3 *__restrict__ conics, // [C, N, 3] or [nnz, 3] + const S 
*__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] + const S *__restrict__ opacities, // [C, N] or [nnz] + const vec10 *__restrict__ view2gaussians, // [C, N, 10] or [nnz, 10] + const S *__restrict__ Ks, // [C, 3, 3] + const S *__restrict__ backgrounds, // [C, COLOR_DIM] or [nnz, COLOR_DIM] + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const uint32_t tile_width, const uint32_t tile_height, + const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] + const int32_t *__restrict__ flatten_ids, // [n_isects] + // fwd outputs + const S *__restrict__ render_alphas, // [C, image_height, image_width, 1] + const int32_t *__restrict__ last_ids, // [C, image_height, image_width] + // grad outputs + const S *__restrict__ v_render_colors, // [C, image_height, image_width, + // COLOR_DIM] + const S *__restrict__ v_render_alphas, // [C, image_height, image_width, 1] + // grad inputs + vec2 *__restrict__ v_means2d_abs, // [C, N, 2] or [nnz, 2] + vec2 *__restrict__ v_means2d, // [C, N, 2] or [nnz, 2] + vec3 *__restrict__ v_conics, // [C, N, 3] or [nnz, 3] + S *__restrict__ v_colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] + S *__restrict__ v_opacities, // [C, N] or [nnz] + S *__restrict__ v_view2gaussians // [C, N, 10] or [nnz, 10] +) { + auto block = cg::this_thread_block(); + uint32_t camera_id = block.group_index().x; + uint32_t tile_id = block.group_index().y * tile_width + block.group_index().z; + uint32_t i = block.group_index().y * tile_size + block.thread_index().y; + uint32_t j = block.group_index().z * tile_size + block.thread_index().x; + + tile_offsets += camera_id * tile_height * tile_width; + render_alphas += camera_id * image_height * image_width; + last_ids += camera_id * image_height * image_width * 2; + v_render_colors += camera_id * image_height * image_width * (COLOR_DIM + 1 + 3); + v_render_alphas += camera_id * image_height * image_width; + Ks += camera_id * 9; + if (backgrounds != nullptr) { + backgrounds += camera_id * COLOR_DIM; + } + + const S px = (S)j + 0.5f; + const S py = (S)i + 0.5f; + // clamp this value to the last pixel + const int32_t pix_id = min(i * image_width + j, image_width * image_height - 1); + + const S focal_x = Ks[0]; + const S focal_y = Ks[4]; + const S cx = Ks[2]; + const S cy = Ks[5]; + const vec3 ray = {(px - cx) / focal_x, (py - cy) / focal_y, 1.0}; + + // keep not rasterizing threads around for reading data + bool inside = (i < image_height && j < image_width); + + // have all threads in tile process the same gaussians in batches + // first collect gaussians between range.x and range.y in batches + // which gaussians to look through in this tile + int32_t range_start = tile_offsets[tile_id]; + int32_t range_end = + (camera_id == C - 1) && (tile_id == tile_width * tile_height - 1) + ? 
n_isects + : tile_offsets[tile_id + 1]; + const uint32_t block_size = block.size(); + const uint32_t num_batches = + (range_end - range_start + block_size - 1) / block_size; + + extern __shared__ int s[]; + int32_t *id_batch = (int32_t *)s; // [block_size] + vec3 *xy_opacity_batch = + reinterpret_cast *>(&id_batch[block_size]); // [block_size] + vec3 *conic_batch = + reinterpret_cast *>(&xy_opacity_batch[block_size]); // [block_size] + vec10 *view2gaussian_batch = + reinterpret_cast *>(&conic_batch[block_size]); // [block_size] + S *rgbs_batch = (S *)&view2gaussian_batch[block_size]; // [block_size * COLOR_DIM] + + + // this is the T AFTER the last gaussian in this pixel + S T_final = 1.0f - render_alphas[pix_id]; + S T = T_final; + // the contribution from gaussians behind the current one + S buffer[COLOR_DIM] = {0.f}; + S buffer_normal[3] = {0.f}; + // index of last gaussian to contribute to this pixel + const int32_t bin_final = inside ? last_ids[pix_id * 2] : 0; + const int32_t bin_max = inside? last_ids[pix_id * 2 + 1] : 0; + // df/d_out for this pixel + S v_render_c[COLOR_DIM]; + PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + v_render_c[k] = v_render_colors[pix_id * (COLOR_DIM + 1 + 3) + k]; + } + const S v_render_a = v_render_alphas[pix_id]; + // gradient for normal and depth + S v_render_normal[3]; + PRAGMA_UNROLL + for (uint32_t k = 0; k < 3; ++k) { + v_render_normal[k] = v_render_colors[pix_id * (COLOR_DIM + 1 + 3) + COLOR_DIM + k]; + } + S v_render_depth = v_render_colors[pix_id * (COLOR_DIM + 1 + 3) + COLOR_DIM + 3]; + + // collect and process batches of gaussians + // each thread loads one gaussian at a time before rasterizing + const uint32_t tr = block.thread_rank(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + const int32_t warp_bin_final = cg::reduce(warp, bin_final, cg::greater()); + for (uint32_t b = 0; b < num_batches; ++b) { + // resync all threads before writing next batch of shared mem + block.sync(); + + // each thread fetch 1 gaussian from back to front + // 0 index will be furthest back in batch + // index of gaussian to load + // batch end is the index of the last gaussian in the batch + // These values can be negative so must be int32 instead of uint32 + const int32_t batch_end = range_end - 1 - block_size * b; + const int32_t batch_size = min(block_size, batch_end + 1 - range_start); + const int32_t idx = batch_end - tr; + if (idx >= range_start) { + int32_t g = flatten_ids[idx]; // flatten index in [C * N] or [nnz] + id_batch[tr] = g; + const vec2 xy = means2d[g]; + const S opac = opacities[g]; + xy_opacity_batch[tr] = {xy.x, xy.y, opac}; + conic_batch[tr] = conics[g]; + view2gaussian_batch[tr] = view2gaussians[g]; + PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + rgbs_batch[tr * COLOR_DIM + k] = colors[g * COLOR_DIM + k]; + } + } + // wait for other threads to collect the gaussians in batch + block.sync(); + // process gaussians in the current batch for this pixel + // 0 index is the furthest back gaussian in the batch + for (uint32_t t = max(0, batch_end - warp_bin_final); t < batch_size; ++t) { + bool valid = inside; + if (batch_end - t > bin_final) { + valid = 0; + } + S alpha; + S opac; + vec2 delta; + vec3 conic; + const vec10 view2gaussian = view2gaussian_batch[t]; + vec3 normal; + S vis; + S AA, BB, CC, depth, min_value, power; + + if (valid) { + conic = conic_batch[t]; + vec3 xy_opac = xy_opacity_batch[t]; + + opac = xy_opac.z; + delta = {xy_opac.x - px, xy_opac.y - py}; + // S sigma = + // 0.5f * 
(conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) + + // conic.y * delta.x * delta.y; + // vis = __expf(-sigma); + // alpha = min(0.999f, opac * vis); + + normal = { + view2gaussian[0] * ray.x + view2gaussian[1] * ray.y + view2gaussian[2], + view2gaussian[1] * ray.x + view2gaussian[3] * ray.y + view2gaussian[4], + view2gaussian[2] * ray.x + view2gaussian[4] * ray.y + view2gaussian[5] + }; + + // use AA, BB, CC so that the name is unique + AA = ray.x * normal[0] + ray.y * normal[1] + normal[2]; + BB = 2 * (view2gaussian[6] * ray.x + view2gaussian[7] * ray.y + view2gaussian[8]); + CC = view2gaussian[9]; + + // t is the depth of the gaussian + depth = -BB/(2*AA); + //TODO take near plane as input + #define NEAR_PLANE 0.01f + // depth must be positive otherwise it is not valid and we skip it + if (depth <= NEAR_PLANE) + valid = false; + + // the scale of the gaussian is 1.f / sqrt(AA) + min_value = -(BB/AA) * (BB/4.) + CC; + + power = -0.5f * min_value; + if (power > 0.0f){ + power = 0.0f; + } + + vis = exp(power); + alpha = min(0.999f, opac * vis); + + if (alpha < 1.f / 255.f) { + valid = false; + } + } + + // if all threads are inactive in this warp, skip this loop + if (!warp.any(valid)) { + continue; + } + S v_rgb_local[COLOR_DIM] = {0.f}; + S dL_dnormal_normalized[3] = {0.f}; + vec3 v_conic_local = {0.f, 0.f, 0.f}; + vec2 v_xy_local = {0.f, 0.f}; + vec2 v_xy_abs_local = {0.f, 0.f}; + S v_opacity_local = 0.f; + S v_view2gaussian_local[10] = {0.f}; + // initialize everything to 0, only set if the lane is valid + if (valid) { + + const S length = sqrt(normal[0] * normal[0] + normal[1] * normal[1] + normal[2] * normal[2] + 1e-7); + const vec3 normal_normalized = { -normal[0] / length, -normal[1] / length, -normal[2] / length }; + + // compute the current T for this gaussian + S ra = 1.0f / (1.0f - alpha); + T *= ra; + // update v_rgb for this gaussian + const S fac = alpha * T; + PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + v_rgb_local[k] = fac * v_render_c[k]; + } + // contribution from normal and depth + PRAGMA_UNROLL + for (uint32_t k = 0; k < 3; ++k) { + dL_dnormal_normalized[k] = fac * v_render_normal[k]; + } + // contribution from this pixel + S v_alpha = 0.f; + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + v_alpha += (rgbs_batch[t * COLOR_DIM + k] * T - buffer[k] * ra) * + v_render_c[k]; + } + // contribution from this pixel's normal + for (uint32_t k = 0; k < 3; ++k) { + v_alpha += (normal_normalized[k] * T - buffer_normal[k] * ra) * + v_render_normal[k]; + } + + v_alpha += T_final * ra * v_render_a; + // contribution from background pixel + if (backgrounds != nullptr) { + S accum = 0.f; + PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + accum += backgrounds[k] * v_render_c[k]; + } + v_alpha += -T_final * ra * accum; + } + // here is different to 3DGS, in 3DGS the gradient is computed even if opac * vis > 0.999f + if (opac * vis <= 0.999f) { + const S dL_dG = opac * v_alpha; + const S v_sigma = -opac * vis * v_alpha; + v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y), + v_sigma * (conic.y * delta.x + conic.z * delta.y)}; + + if (v_means2d_abs != nullptr) { + v_xy_abs_local = {abs(v_xy_local.x), abs(v_xy_local.y)}; + } + v_opacity_local = vis * v_alpha; + + // gradient of depth + S dL_dt = 0.0f; + vec3 dL_dnormal = {0.0f, 0.0f, 0.0f}; + // float length = sqrt(normal[0] * normal[0] + normal[1] * normal[1] + normal[2] * normal[2] + 1e-7); + // const float normal_normalized[3] = { -normal[0] / length, -normal[1] / length, -normal[2] / 
length}; + S dL_dlength = (dL_dnormal_normalized[0] * normal[0] + dL_dnormal_normalized[1] * normal[1] + dL_dnormal_normalized[2] * normal[2]); + dL_dlength *= 1.f / (length * length); + dL_dnormal += vec3( + (-dL_dnormal_normalized[0] + dL_dlength * normal[0]) / length, + (-dL_dnormal_normalized[1] + dL_dlength * normal[1]) / length, + (-dL_dnormal_normalized[2] + dL_dlength * normal[2]) / length + ); + + if (batch_end - t == bin_max){ + dL_dt += v_render_depth; + } + + // vis = exp(power); + const S dG_dpower = vis; + const S dL_dpower = dL_dG * dG_dpower; + + // // float power = -0.5f * min_value; + const S dL_dmin_value = dL_dpower * -0.5f; + // float min_value = -(BB*BB)/(4*AA) + CC; + // const float dL_dA = dL_dmin_value * (BB*BB)/4 * 1. / (AA*AA); + S dL_dA = dL_dmin_value * (BB / AA) * (BB / AA) / 4.f; + S dL_dB = dL_dmin_value * -BB / (2 *AA); + S dL_dC = dL_dmin_value * 1.0f; + // from depth = -BB/(2*AA) + dL_dA += dL_dt * BB / (2 * AA * AA); + dL_dB += dL_dt * -1.f / (2 * AA); + + // const float normal[3] = { view2gaussian_j[0] * ray.x + view2gaussian_j[1] * ray.y + view2gaussian_j[2], + // view2gaussian_j[1] * ray.x + view2gaussian_j[3] * ray.y + view2gaussian_j[4], + // view2gaussian_j[2] * ray.x + view2gaussian_j[4] * ray.y + view2gaussian_j[5]}; + + // use AA, BB, CC so that the name is unique + // float AA = ray.x * normal[0] + ray.y * normal[1] + normal[2]; + // float BB = 2 * (view2gaussian_j[6] * ray_point.x + view2gaussian_j[7] * ray_point.y + view2gaussian_j[8]); + // float CC = view2gaussian_j[9]; + dL_dnormal[0] += dL_dA * ray.x; + dL_dnormal[1] += dL_dA * ray.y; + dL_dnormal[2] += dL_dA; + + // write the gradients to global memory directly + // atomicAdd(&(dL_dview2gaussian[global_id * 10 + 0]), dL_dnormal[0] * ray.x); + // atomicAdd(&(dL_dview2gaussian[global_id * 10 + 1]), dL_dnormal[0] * ray.y + dL_dnormal[1] * ray.x); + // atomicAdd(&(dL_dview2gaussian[global_id * 10 + 2]), dL_dnormal[0] + dL_dnormal[2] * ray.x); + // atomicAdd(&(dL_dview2gaussian[global_id * 10 + 3]), dL_dnormal[1] * ray.y); + // atomicAdd(&(dL_dview2gaussian[global_id * 10 + 4]), dL_dnormal[1] + dL_dnormal[2] * ray.y); + // atomicAdd(&(dL_dview2gaussian[global_id * 10 + 5]), dL_dnormal[2]); + // atomicAdd(&(dL_dview2gaussian[global_id * 10 + 6]), dL_dB * 2 * ray.x); + // atomicAdd(&(dL_dview2gaussian[global_id * 10 + 7]), dL_dB * 2 * ray.y); + // atomicAdd(&(dL_dview2gaussian[global_id * 10 + 8]), dL_dB * 2); + // atomicAdd(&(dL_dview2gaussian[global_id * 10 + 9]), dL_dC); + v_view2gaussian_local[0] = dL_dnormal[0] * ray.x; + v_view2gaussian_local[1] = dL_dnormal[0] * ray.y + dL_dnormal[1] * ray.x; + v_view2gaussian_local[2] = dL_dnormal[0] + dL_dnormal[2] * ray.x; + v_view2gaussian_local[3] = dL_dnormal[1] * ray.y; + v_view2gaussian_local[4] = dL_dnormal[1] + dL_dnormal[2] * ray.y; + v_view2gaussian_local[5] = dL_dnormal[2]; + v_view2gaussian_local[6] = dL_dB * 2 * ray.x; + v_view2gaussian_local[7] = dL_dB * 2 * ray.y; + v_view2gaussian_local[8] = dL_dB * 2; + v_view2gaussian_local[9] = dL_dC; + } + + PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + buffer[k] += rgbs_batch[t * COLOR_DIM + k] * fac; + } + PRAGMA_UNROLL + for (uint32_t k = 0; k < 3; ++k) { + buffer_normal[k] += normal_normalized[k] * fac; + } + } + warpSum(v_rgb_local, warp); + warpSum<10, S>(v_view2gaussian_local, warp); + warpSum(v_conic_local, warp); + warpSum(v_xy_local, warp); + if (v_means2d_abs != nullptr) { + warpSum(v_xy_abs_local, warp); + } + warpSum(v_opacity_local, warp); + if (warp.thread_rank() == 0) 
{ + int32_t g = id_batch[t]; // flatten index in [C * N] or [nnz] + S *v_rgb_ptr = (S *)(v_colors) + COLOR_DIM * g; + PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + gpuAtomicAdd(v_rgb_ptr + k, v_rgb_local[k]); + } + + S *v_conic_ptr = (S *)(v_conics) + 3 * g; + gpuAtomicAdd(v_conic_ptr, v_conic_local.x); + gpuAtomicAdd(v_conic_ptr + 1, v_conic_local.y); + gpuAtomicAdd(v_conic_ptr + 2, v_conic_local.z); + + S *v_xy_ptr = (S *)(v_means2d) + 2 * g; + gpuAtomicAdd(v_xy_ptr, v_xy_local.x); + gpuAtomicAdd(v_xy_ptr + 1, v_xy_local.y); + + if (v_means2d_abs != nullptr) { + S *v_xy_abs_ptr = (S *)(v_means2d_abs) + 2 * g; + gpuAtomicAdd(v_xy_abs_ptr, v_xy_abs_local.x); + gpuAtomicAdd(v_xy_abs_ptr + 1, v_xy_abs_local.y); + } + + gpuAtomicAdd(v_opacities + g, v_opacity_local); + PRAGMA_UNROLL + for (uint32_t k = 0; k < 10; ++k) { + gpuAtomicAdd((S *)(v_view2gaussians) + 10 * g + k, v_view2gaussian_local[k]); + } + } + } + } +} + +template +std::tuple +call_kernel_with_dim( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &view2gaussians, // [C, N, 10] or [nnz, 10] + const torch::Tensor &Ks, // [C, 3, 3] + const at::optional &backgrounds, // [C, 3] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // forward outputs + const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &last_ids, // [C, image_height, image_width] + // gradients of outputs + const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + // options + bool absgrad) { + + DEVICE_GUARD(means2d); + CHECK_INPUT(means2d); + CHECK_INPUT(conics); + CHECK_INPUT(colors); + CHECK_INPUT(opacities); + CHECK_INPUT(view2gaussians); + CHECK_INPUT(Ks); + CHECK_INPUT(tile_offsets); + CHECK_INPUT(flatten_ids); + CHECK_INPUT(render_alphas); + CHECK_INPUT(last_ids); + CHECK_INPUT(v_render_colors); + CHECK_INPUT(v_render_alphas); + if (backgrounds.has_value()) { + CHECK_INPUT(backgrounds.value()); + } + + bool packed = means2d.dim() == 2; + + uint32_t C = tile_offsets.size(0); // number of cameras + uint32_t N = packed ? 0 : means2d.size(1); // number of gaussians + uint32_t n_isects = flatten_ids.size(0); + uint32_t COLOR_DIM = colors.size(-1); + uint32_t tile_height = tile_offsets.size(1); + uint32_t tile_width = tile_offsets.size(2); + + // Each block covers a tile on the image. In total there are + // C * tile_height * tile_width blocks. 
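The backward kernel above differentiates the per-ray quadratic in closed form: with depth = -BB/(2*AA) and min_value = CC - BB*BB/(4*AA), it relies on d(depth)/dAA = BB/(2*AA^2), d(depth)/dBB = -1/(2*AA), d(min_value)/dAA = BB^2/(4*AA^2), d(min_value)/dBB = -BB/(2*AA), and d(min_value)/dCC = 1. A minimal PyTorch autograd check of exactly those expressions (purely illustrative and not part of the patch; the scalar names are hypothetical stand-ins for one ray/Gaussian pair, and the power<=0 clamp is omitted since these sample values keep power negative):

    import torch

    # Scalar stand-ins for the per-ray quadratic coefficients AA, BB, CC.
    AA = torch.tensor(1.7, requires_grad=True)
    BB = torch.tensor(-0.9, requires_grad=True)
    CC = torch.tensor(0.4, requires_grad=True)

    depth = -BB / (2 * AA)                    # ray depth of the peak Gaussian response
    min_value = -(BB * BB) / (4 * AA) + CC    # quadratic value at that depth
    power = -0.5 * min_value                  # kernel clamps this at 0; not clamped here
    vis = torch.exp(power)                    # Gaussian response G along the ray
    (depth + vis).backward()                  # pretend dL/d(depth) = dL/dG = 1

    with torch.no_grad():
        dL_dmin = vis * -0.5                  # dG/dpower = G, dpower/dmin_value = -0.5
        dAA = dL_dmin * (BB / AA) ** 2 / 4 + BB / (2 * AA * AA)
        dBB = dL_dmin * -BB / (2 * AA) - 1.0 / (2 * AA)
        dCC = dL_dmin
    assert torch.allclose(AA.grad, dAA)
    assert torch.allclose(BB.grad, dBB)
    assert torch.allclose(CC.grad, dCC)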
+ dim3 threads = {tile_size, tile_size, 1}; + dim3 blocks = {C, tile_height, tile_width}; + + torch::Tensor v_means2d = torch::zeros_like(means2d); + torch::Tensor v_conics = torch::zeros_like(conics); + torch::Tensor v_colors = torch::zeros_like(colors); + torch::Tensor v_opacities = torch::zeros_like(opacities); + torch::Tensor v_means2d_abs; + torch::Tensor v_view2gaussians = torch::zeros_like(view2gaussians); + if (absgrad) { + v_means2d_abs = torch::zeros_like(means2d); + } + + if (n_isects) { + const uint32_t shared_mem = tile_size * tile_size * + (sizeof(int32_t) + sizeof(vec3) + + sizeof(vec3) + sizeof(vec10) + sizeof(float) * COLOR_DIM); + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + if (cudaFuncSetAttribute(raytracing_to_pixels_bwd_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem) != cudaSuccess) { + AT_ERROR("Failed to set maximum shared memory size (requested ", shared_mem, + " bytes), try lowering tile_size."); + } + raytracing_to_pixels_bwd_kernel + <<>>( + C, N, n_isects, packed, + reinterpret_cast *>(means2d.data_ptr()), + reinterpret_cast *>(conics.data_ptr()), + colors.data_ptr(), opacities.data_ptr(), + reinterpret_cast *>(view2gaussians.data_ptr()), + Ks.data_ptr(), + backgrounds.has_value() ? backgrounds.value().data_ptr() + : nullptr, + image_width, image_height, tile_size, tile_width, tile_height, + tile_offsets.data_ptr(), flatten_ids.data_ptr(), + render_alphas.data_ptr(), last_ids.data_ptr(), + v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + absgrad + ? reinterpret_cast *>(v_means2d_abs.data_ptr()) + : nullptr, + reinterpret_cast *>(v_means2d.data_ptr()), + reinterpret_cast *>(v_conics.data_ptr()), + v_colors.data_ptr(), v_opacities.data_ptr(), + v_view2gaussians.data_ptr()); + } + + return std::make_tuple(v_means2d_abs, v_means2d, v_conics, v_colors, v_opacities, v_view2gaussians); +} + +std::tuple +raytracing_to_pixels_bwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &view2gaussians, // [C, N, 10] or [nnz, 10] + const torch::Tensor &Ks, // [C, 3, 3] + const at::optional &backgrounds, // [C, 3] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // forward outputs + const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &last_ids, // [C, image_height, image_width] + // gradients of outputs + const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + // options + bool absgrad) { + + CHECK_INPUT(colors); + uint32_t COLOR_DIM = colors.size(-1); + +#define __GS__CALL_(N) \ + case N: \ + return call_kernel_with_dim( \ + means2d, conics, colors, opacities, view2gaussians, Ks, \ + backgrounds, image_width, \ + image_height, tile_size, tile_offsets, flatten_ids, render_alphas, \ + last_ids, v_render_colors, v_render_alphas, absgrad); + + switch (COLOR_DIM) { + __GS__CALL_(1) + __GS__CALL_(2) + __GS__CALL_(3) + __GS__CALL_(4) + __GS__CALL_(5) + __GS__CALL_(8) + __GS__CALL_(9) + __GS__CALL_(16) + __GS__CALL_(17) + __GS__CALL_(32) + __GS__CALL_(33) + __GS__CALL_(64) + 
__GS__CALL_(65) + __GS__CALL_(128) + __GS__CALL_(129) + __GS__CALL_(256) + __GS__CALL_(257) + __GS__CALL_(512) + __GS__CALL_(513) + default: + AT_ERROR("Unsupported number of channels: ", COLOR_DIM); + } +} diff --git a/gsplat/cuda/csrc/raytracing_to_pixels_fwd.cu b/gsplat/cuda/csrc/raytracing_to_pixels_fwd.cu new file mode 100644 index 000000000..f4fa31002 --- /dev/null +++ b/gsplat/cuda/csrc/raytracing_to_pixels_fwd.cu @@ -0,0 +1,375 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "types.cuh" +#include +#include +#include + +namespace cg = cooperative_groups; + +/**************************************************************************** + * Rasterization to Pixels Forward Pass + ****************************************************************************/ + +template +__global__ void raytracing_to_pixels_fwd_kernel( + const uint32_t C, const uint32_t N, const uint32_t n_isects, const bool packed, + const vec2 *__restrict__ means2d, // [C, N, 2] or [nnz, 2] + const vec3 *__restrict__ conics, // [C, N, 3] or [nnz, 3] + const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] + const S *__restrict__ opacities, // [C, N] or [nnz] + const vec10 *__restrict__ view2gaussians, // [C, N, 10] or [nnz, 10] + const S *__restrict__ Ks, // [C, 3, 3] + const S *__restrict__ backgrounds, // [C, COLOR_DIM] + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + const uint32_t tile_width, const uint32_t tile_height, + const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] + const int32_t *__restrict__ flatten_ids, // [n_isects] + S *__restrict__ render_colors, // [C, image_height, image_width, COLOR_DIM] + S *__restrict__ render_alphas, // [C, image_height, image_width, 1] + int32_t *__restrict__ last_ids // [C, image_height, image_width] +) { + // each thread draws one pixel, but also timeshares caching gaussians in a + // shared tile + + auto block = cg::this_thread_block(); + int32_t camera_id = block.group_index().x; + int32_t tile_id = block.group_index().y * tile_width + block.group_index().z; + uint32_t i = block.group_index().y * tile_size + block.thread_index().y; + uint32_t j = block.group_index().z * tile_size + block.thread_index().x; + + tile_offsets += camera_id * tile_height * tile_width; + render_colors += camera_id * image_height * image_width * (COLOR_DIM + 1 + 3); + render_alphas += camera_id * image_height * image_width; + last_ids += camera_id * image_height * image_width * 2; + Ks += camera_id * 9; + + if (backgrounds != nullptr) { + backgrounds += camera_id * COLOR_DIM; + } + + S px = (S)j + 0.5f; + S py = (S)i + 0.5f; + int32_t pix_id = i * image_width + j; + + const S focal_x = Ks[0]; + const S focal_y = Ks[4]; + const S cx = Ks[2]; + const S cy = Ks[5]; + const vec3 ray = {(px - cx) / focal_x, (py - cy) / focal_y, 1.0}; + + // return if out of bounds + // keep not rasterizing threads around for reading data + bool inside = (i < image_height && j < image_width); + bool done = !inside; + + // have all threads in tile process the same gaussians in batches + // first collect gaussians between range.x and range.y in batches + // which gaussians to look through in this tile + int32_t range_start = tile_offsets[tile_id]; + int32_t range_end = + (camera_id == C - 1) && (tile_id == tile_width * tile_height - 1) + ? 
n_isects + : tile_offsets[tile_id + 1]; + const uint32_t block_size = block.size(); + uint32_t num_batches = (range_end - range_start + block_size - 1) / block_size; + + extern __shared__ int s[]; + int32_t *id_batch = (int32_t *)s; // [block_size] + vec3 *xy_opacity_batch = + reinterpret_cast *>(&id_batch[block_size]); // [block_size] + vec3 *conic_batch = + reinterpret_cast *>(&xy_opacity_batch[block_size]); // [block_size] + vec10 *view2gaussian_batch = + reinterpret_cast *>(&conic_batch[block_size]); // [block_size] + + // current visibility left to render + // transmittance is gonna be used in the backward pass which requires a high + // numerical precision so we use double for it. However double make bwd 1.5x slower + // so we stick with float for now. + S T = 1.0f; + // index of most recent gaussian to write to this thread's pixel + uint32_t cur_idx = 0; + uint32_t max_contributor = -1; + // collect and process batches of gaussians + // each thread loads one gaussian at a time before rasterizing its + // designated pixel + uint32_t tr = block.thread_rank(); + + // + 1 for depth and + 3 for normal + S pix_out[COLOR_DIM + 1 + 3] = {0.f}; + for (uint32_t b = 0; b < num_batches; ++b) { + // resync all threads before beginning next batch + // end early if entire tile is done + if (__syncthreads_count(done) >= block_size) { + break; + } + + // each thread fetch 1 gaussian from front to back + // index of gaussian to load + uint32_t batch_start = range_start + block_size * b; + uint32_t idx = batch_start + tr; + if (idx < range_end) { + int32_t g = flatten_ids[idx]; // flatten index in [C * N] or [nnz] + id_batch[tr] = g; + const vec2 xy = means2d[g]; + const S opac = opacities[g]; + xy_opacity_batch[tr] = {xy.x, xy.y, opac}; + conic_batch[tr] = conics[g]; + view2gaussian_batch[tr] = view2gaussians[g]; + } + + // wait for other threads to collect the gaussians in batch + block.sync(); + + // process gaussians in the current batch for this pixel + uint32_t batch_size = min(block_size, range_end - batch_start); + for (uint32_t t = 0; (t < batch_size) && !done; ++t) { + const vec3 conic = conic_batch[t]; + const vec3 xy_opac = xy_opacity_batch[t]; + const S opac = xy_opac.z; + // why not use pointer so we don't need to copy again? + const vec10 view2gaussian = view2gaussian_batch[t]; + + const vec3 normal = { + view2gaussian[0] * ray.x + view2gaussian[1] * ray.y + view2gaussian[2], + view2gaussian[1] * ray.x + view2gaussian[3] * ray.y + view2gaussian[4], + view2gaussian[2] * ray.x + view2gaussian[4] * ray.y + view2gaussian[5] + }; + + // use AA, BB, CC so that the name is unique + S AA = ray.x * normal[0] + ray.y * normal[1] + normal[2]; + S BB = 2 * (view2gaussian[6] * ray.x + view2gaussian[7] * ray.y + view2gaussian[8]); + S CC = view2gaussian[9]; + + // t is the depth of the gaussian + S depth = -BB/(2*AA); + + //TODO take near plane as input + #define NEAR_PLANE 0.01f + // depth must be positive otherwise it is not valid and we skip it + if (depth <= NEAR_PLANE) + continue; + + // the scale of the gaussian is 1.f / sqrt(AA) + S min_value = -(BB/AA) * (BB/4.) 
+ CC; + + S power = -0.5f * min_value; + if (power > 0.0f){ + power = 0.0f; + } + + S alpha = min(0.999f, opac * exp(power)); + + if (alpha < 1.f / 255.f) { + continue; + } + + const S next_T = T * (1.0f - alpha); + if (next_T <= 1e-4) { // this pixel is done: exclusive + done = true; + break; + } + + int32_t g = id_batch[t]; + const S vis = alpha * T; + const S *c_ptr = colors + g * COLOR_DIM; + PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + pix_out[k] += c_ptr[k] * vis; + } + + // render depth, normal, distortion + // NDC mapping is taken from 2DGS paper, please check here https://arxiv.org/pdf/2403.17888.pdf + // const float max_t = t; + // const float mapped_max_t = (FAR_PLANE * max_t - FAR_PLANE * NEAR_PLANE) / ((FAR_PLANE - NEAR_PLANE) * max_t); + + // normalize normal + const S length = sqrt(normal[0] * normal[0] + normal[1] * normal[1] + normal[2] * normal[2] + 1e-7); + const vec3 normal_normalized = { -normal[0] / length, -normal[1] / length, -normal[2] / length }; + + // distortion loss is taken from 2DGS paper, please check https://arxiv.org/pdf/2403.17888.pdf + // float A = 1-T; + // float error = mapped_max_t * mapped_max_t * A + dist2 - 2 * mapped_max_t * dist1; + // distortion += error * alpha * T; + + // dist1 += mapped_max_t * alpha * T; + // dist2 += mapped_max_t * mapped_max_t * alpha * T; + + // normal + for (int k = 0; k < 3; k++) + pix_out[COLOR_DIM + k] += normal_normalized[k] * vis; + + // depth and alpha + if (T > 0.5){ + pix_out[COLOR_DIM + 3] = depth; + max_contributor = batch_start + t; + } + + cur_idx = batch_start + t; + + T = next_T; + } + } + + if (inside) { + // Here T is the transmittance AFTER the last gaussian in this pixel. + // We (should) store double precision as T would be used in backward pass and + // it can be very small and causing large diff in gradients with float32. + // However, double precision makes the backward pass 1.5x slower so we stick + // with float for now. + render_alphas[pix_id] = 1.0f - T; + PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + render_colors[pix_id * (COLOR_DIM + 1 + 3) + k] = + backgrounds == nullptr ? 
pix_out[k] : (pix_out[k] + T * backgrounds[k]); + } + // normal + PRAGMA_UNROLL + for (uint32_t k = 0; k < 3; ++k) { + render_colors[pix_id * (COLOR_DIM + 1 + 3) + COLOR_DIM + k] = pix_out[COLOR_DIM + k]; + } + // depth + render_colors[pix_id * (COLOR_DIM + 1 + 3) + COLOR_DIM + 3] = pix_out[COLOR_DIM + 3]; + + // index in bin of last gaussian in this pixel + last_ids[pix_id * 2] = static_cast(cur_idx); + // index in bn of the gaussian that contributes the most to this pixel/depth + last_ids[pix_id * 2 + 1] = static_cast(max_contributor); + } +} + +template +std::tuple call_kernel_with_dim( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &view2gaussians, // [C, N, 10] or [nnz, 10] + const torch::Tensor &Ks, // [C, 3, 3] + const at::optional &backgrounds, // [C, channels] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +) { + DEVICE_GUARD(means2d); + CHECK_INPUT(means2d); + CHECK_INPUT(conics); + CHECK_INPUT(colors); + CHECK_INPUT(opacities); + CHECK_INPUT(view2gaussians); + CHECK_INPUT(Ks); + CHECK_INPUT(tile_offsets); + CHECK_INPUT(flatten_ids); + if (backgrounds.has_value()) { + CHECK_INPUT(backgrounds.value()); + } + bool packed = means2d.dim() == 2; + + uint32_t C = tile_offsets.size(0); // number of cameras + uint32_t N = packed ? 0 : means2d.size(1); // number of gaussians + uint32_t channels = colors.size(-1); + uint32_t tile_height = tile_offsets.size(1); + uint32_t tile_width = tile_offsets.size(2); + uint32_t n_isects = flatten_ids.size(0); + + // Each block covers a tile on the image. In total there are + // C * tile_height * tile_width blocks. + dim3 threads = {tile_size, tile_size, 1}; + dim3 blocks = {C, tile_height, tile_width}; + + // + 1 for depth and + 3 for normal + torch::Tensor renders = torch::empty({C, image_height, image_width, channels + 1 + 3}, + means2d.options().dtype(torch::kFloat32)); + torch::Tensor alphas = torch::empty({C, image_height, image_width, 1}, + means2d.options().dtype(torch::kFloat32)); + // 1 for last_ids and 1 for max_contributor + torch::Tensor last_ids = torch::empty({C, image_height, image_width, 2}, + means2d.options().dtype(torch::kInt32)); + + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + const uint32_t shared_mem = + tile_size * tile_size * + (sizeof(int32_t) + sizeof(vec3) + sizeof(vec3) + sizeof(vec10)); + + // TODO: an optimization can be done by passing the actual number of channels into + // the kernel functions and avoid necessary global memory writes. This requires + // moving the channel padding from python to C side. + if (cudaFuncSetAttribute(raytracing_to_pixels_fwd_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem) != cudaSuccess) { + AT_ERROR("Failed to set maximum shared memory size (requested ", shared_mem, + " bytes), try lowering tile_size."); + } + raytracing_to_pixels_fwd_kernel + <<>>( + C, N, n_isects, packed, + reinterpret_cast *>(means2d.data_ptr()), + reinterpret_cast *>(conics.data_ptr()), + colors.data_ptr(), opacities.data_ptr(), + reinterpret_cast *>(view2gaussians.data_ptr()), + Ks.data_ptr(), + backgrounds.has_value() ? 
backgrounds.value().data_ptr() : nullptr, + image_width, image_height, tile_size, tile_width, tile_height, + tile_offsets.data_ptr(), flatten_ids.data_ptr(), + renders.data_ptr(), alphas.data_ptr(), + last_ids.data_ptr()); + + return std::make_tuple(renders, alphas, last_ids); +} + +std::tuple raytracing_to_pixels_fwd_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &view2gaussians, // [C, N, 10] or [nnz, 10] + const torch::Tensor &Ks, // [C, 3, 3] + const at::optional &backgrounds, // [C, channels] + // image size + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +) { + CHECK_INPUT(colors); + uint32_t channels = colors.size(-1); + +#define __GS__CALL_(N) \ + case N: \ + return call_kernel_with_dim(means2d, conics, colors, opacities, \ + view2gaussians, Ks, \ + backgrounds, image_width, image_height, \ + tile_size, tile_offsets, flatten_ids); + + // TODO: an optimization can be done by passing the actual number of channels into + // the kernel functions and avoid necessary global memory writes. This requires + // moving the channel padding from python to C side. + switch (channels) { + __GS__CALL_(1) + __GS__CALL_(2) + __GS__CALL_(3) + __GS__CALL_(4) + __GS__CALL_(5) + __GS__CALL_(8) + __GS__CALL_(9) + __GS__CALL_(16) + __GS__CALL_(17) + __GS__CALL_(32) + __GS__CALL_(33) + __GS__CALL_(64) + __GS__CALL_(65) + __GS__CALL_(128) + __GS__CALL_(129) + __GS__CALL_(256) + __GS__CALL_(257) + __GS__CALL_(512) + __GS__CALL_(513) + default: + AT_ERROR("Unsupported number of channels: ", channels); + } +} diff --git a/gsplat/cuda/csrc/types.cuh b/gsplat/cuda/csrc/types.cuh index 116059444..49ab85dc1 100644 --- a/gsplat/cuda/csrc/types.cuh +++ b/gsplat/cuda/csrc/types.cuh @@ -7,6 +7,7 @@ #include #include +#include #include @@ -16,6 +17,8 @@ template using vec3 = glm::vec<3, T>; template using vec4 = glm::vec<4, T>; +template using vec10 = std::array; + template using mat2 = glm::mat<2, 2, T>; template using mat3 = glm::mat<3, 3, T>; diff --git a/gsplat/cuda/csrc/view_to_gaussians_bwd.cu b/gsplat/cuda/csrc/view_to_gaussians_bwd.cu new file mode 100644 index 000000000..f11647dee --- /dev/null +++ b/gsplat/cuda/csrc/view_to_gaussians_bwd.cu @@ -0,0 +1,252 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "utils.cuh" + +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + + +/**************************************************************************** + * Projection of Gaussians (Single Batch) Backward Pass + ****************************************************************************/ + +template +__global__ void view_to_gaussians_bwd_kernel( + // fwd inputs + const uint32_t C, const uint32_t N, + const T *__restrict__ means, // [N, 3] + const T *__restrict__ quats, // [N, 4] + const T *__restrict__ scales, // [N, 3] + const T *__restrict__ camtoworlds, // [C, 4, 4] TODO world-to-cam + const int32_t *__restrict__ radii, // [C, N] + // fwd outputs + const T *__restrict__ view2gaussians, // [C, N, 10] + // grad outputs + const T *__restrict__ v_view2gaussians, // [C, N, 10] + // grad inputs + T *__restrict__ v_means, // [N, 3] + T *__restrict__ v_quats, // 
[N, 4] + T *__restrict__ v_scales, // [N, 3] + T *__restrict__ v_camtoworlds // [C, 4, 4] optional +) { + // parallelize over C * N. + uint32_t idx = cg::this_grid().thread_rank(); + if (idx >= C * N || radii[idx] <= 0) { + return; + } + const uint32_t cid = idx / N; // camera id + const uint32_t gid = idx % N; // gaussian id + + // shift pointers to the current camera and gaussian + means += gid * 3; + camtoworlds += cid * 16; + quats += gid * 4; + scales += gid * 3; + view2gaussians += idx * 10; + v_view2gaussians += idx * 10; + + // glm is column-major but input is row-major + mat3 camtoworlds_R = mat3( + camtoworlds[0], camtoworlds[4], camtoworlds[8], // 1st column + camtoworlds[1], camtoworlds[5], camtoworlds[9], // 2nd column + camtoworlds[2], camtoworlds[6], camtoworlds[10] // 3rd column + ); + vec3 camtoworlds_t = vec3(camtoworlds[3], camtoworlds[7], camtoworlds[11]); + + vec3 mean = glm::make_vec3(means); + vec4 quat = glm::make_vec4(quats); + vec3 scale = glm::make_vec3(scales); + mat3 rotmat = quat_to_rotmat(quat); + + mat3 worldtogaussian_R = glm::transpose(rotmat); + vec3 worldtogaussian_t = -worldtogaussian_R * mean; + + mat3 view2gaussian_R = worldtogaussian_R * camtoworlds_R; + vec3 view2gaussian_t = worldtogaussian_R * camtoworlds_t + worldtogaussian_t; + + vec3 scales_inv_square = {1.0f / (scale.x * scale.x + 1e-10f), 1.0f / (scale.y * scale.y + 1e-10f), 1.0f / (scale.z * scale.z + 1e-10f)}; + T CC = view2gaussian_t.x * view2gaussian_t.x * scales_inv_square.x + \ + view2gaussian_t.y * view2gaussian_t.y * scales_inv_square.y + \ + view2gaussian_t.z * view2gaussian_t.z * scales_inv_square.z; + + mat3 scales_inv_square_R = mat3( + scales_inv_square.x * view2gaussian_R[0][0], scales_inv_square.y * view2gaussian_R[0][1], scales_inv_square.z * view2gaussian_R[0][2], + scales_inv_square.x * view2gaussian_R[1][0], scales_inv_square.y * view2gaussian_R[1][1], scales_inv_square.z * view2gaussian_R[1][2], + scales_inv_square.x * view2gaussian_R[2][0], scales_inv_square.y * view2gaussian_R[2][1], scales_inv_square.z * view2gaussian_R[2][2] + ); + + // gradient + mat3 dL_dSigma = mat3( + v_view2gaussians[0], 0.5f * v_view2gaussians[1], 0.5f * v_view2gaussians[2], + 0.5f * v_view2gaussians[1], v_view2gaussians[3], 0.5f * v_view2gaussians[4], + 0.5f * v_view2gaussians[2], 0.5f * v_view2gaussians[4], v_view2gaussians[5] + ); + vec3 dL_dB = vec3(v_view2gaussians[6], v_view2gaussians[7], v_view2gaussians[8]); + T dL_dCC = v_view2gaussians[9]; + + // TODO + // vec3 BB = view2gaussian_t * scales_inv_square_R; + // mat3 Sigma = glm::transpose(view2gaussian_R) * scales_inv_square_R; + mat3 dL_dS_inv_square_R = view2gaussian_R * dL_dSigma + glm::outerProduct(view2gaussian_t, dL_dB); + mat3 v_view2gaussian_R = glm::transpose(dL_dSigma * glm::transpose(scales_inv_square_R)); + + // mat3 scales_inv_square_R = mat3( + // scales_inv_square.x * view2gaussian_R[0][0], scales_inv_square.y * view2gaussian_R[0][1], scales_inv_square.z * view2gaussian_R[0][2], + // scales_inv_square.x * view2gaussian_R[1][0], scales_inv_square.y * view2gaussian_R[1][1], scales_inv_square.z * view2gaussian_R[1][2], + // scales_inv_square.x * view2gaussian_R[2][0], scales_inv_square.y * view2gaussian_R[2][1], scales_inv_square.z * view2gaussian_R[2][2] + // ); + v_view2gaussian_R += mat3( + scales_inv_square.x * dL_dS_inv_square_R[0][0], scales_inv_square.y * dL_dS_inv_square_R[0][1], scales_inv_square.z * dL_dS_inv_square_R[0][2], + scales_inv_square.x * dL_dS_inv_square_R[1][0], scales_inv_square.y * 
dL_dS_inv_square_R[1][1], scales_inv_square.z * dL_dS_inv_square_R[1][2], + scales_inv_square.x * dL_dS_inv_square_R[2][0], scales_inv_square.y * dL_dS_inv_square_R[2][1], scales_inv_square.z * dL_dS_inv_square_R[2][2] + ); + vec3 dL_dS_inv_square = vec3( + dL_dS_inv_square_R[0][0] * view2gaussian_R[0][0] + dL_dS_inv_square_R[1][0] * view2gaussian_R[1][0] + dL_dS_inv_square_R[2][0] * view2gaussian_R[2][0], + dL_dS_inv_square_R[0][1] * view2gaussian_R[0][1] + dL_dS_inv_square_R[1][1] * view2gaussian_R[1][1] + dL_dS_inv_square_R[2][1] * view2gaussian_R[2][1], + dL_dS_inv_square_R[0][2] * view2gaussian_R[0][2] + dL_dS_inv_square_R[1][2] * view2gaussian_R[1][2] + dL_dS_inv_square_R[2][2] * view2gaussian_R[2][2] + ); + // T CC = view2gaussian_t.x * view2gaussian_t.x * scales_inv_square.x + \ + // view2gaussian_t.y * view2gaussian_t.y * scales_inv_square.y + \ + // view2gaussian_t.z * view2gaussian_t.z * scales_inv_square.z; + vec3 v_view2gaussian_t = vec3( + 2.f * view2gaussian_t.x * scales_inv_square.x * dL_dCC + dL_dB.x * scales_inv_square_R[0][0] + dL_dB.y * scales_inv_square_R[1][0] + dL_dB.z * scales_inv_square_R[2][0], + 2.f * view2gaussian_t.y * scales_inv_square.y * dL_dCC + dL_dB.x * scales_inv_square_R[0][1] + dL_dB.y * scales_inv_square_R[1][1] + dL_dB.z * scales_inv_square_R[2][1], + 2.f * view2gaussian_t.z * scales_inv_square.z * dL_dCC + dL_dB.x * scales_inv_square_R[0][2] + dL_dB.y * scales_inv_square_R[1][2] + dL_dB.z * scales_inv_square_R[2][2] + ); + dL_dS_inv_square.x += dL_dCC * view2gaussian_t.x * view2gaussian_t.x; + dL_dS_inv_square.y += dL_dCC * view2gaussian_t.y * view2gaussian_t.y; + dL_dS_inv_square.z += dL_dCC * view2gaussian_t.z * view2gaussian_t.z; + // vec3 scales_inv_square = {1.0f / (scale.x * scale.x + 1e-7), 1.0f / (scale.y * scale.y + 1e-7), 1.0f / (scale.z * scale.z + 1e-7)}; + vec3 v_scale = vec3( + -2.f / scale.x * scales_inv_square.x * dL_dS_inv_square.x, + -2.f / scale.y * scales_inv_square.y * dL_dS_inv_square.y, + -2.f / scale.z * scales_inv_square.z * dL_dS_inv_square.z + ); + + // mat3 view2gaussian_R = worldtogaussian_R * camtoworlds_R; + // vec3 view2gaussian_t = worldtogaussian_R * camtoworlds_t + worldtogaussian_t; + vec3 v_worldtogaussian_t = v_view2gaussian_t; + mat3 v_worldtogaussian_R = v_view2gaussian_R * glm::transpose(camtoworlds_R) + glm::outerProduct(v_view2gaussian_t, camtoworlds_t); + // vec3 worldtogaussian_t = -worldtogaussian_R * mean; + v_worldtogaussian_R -= glm::outerProduct(v_worldtogaussian_t, mean); + vec3 v_mean = -glm::transpose(worldtogaussian_R) * v_worldtogaussian_t; + // mat3 worldtogaussian_R = glm::transpose(rotmat); + mat3 v_rotmat = glm::transpose(v_worldtogaussian_R); + + // grad for quat rotmat + vec4 v_quat(0.f); + quat_to_rotmat_vjp(quat, v_rotmat, v_quat); + + // #if __CUDA_ARCH__ >= 700 + // write out results with warp-level reduction + auto warp = cg::tiled_partition<32>(cg::this_thread_block()); + auto warp_group_g = cg::labeled_partition(warp, gid); + warpSum(v_mean, warp_group_g); + if (warp_group_g.thread_rank() == 0) { + v_means += gid * 3; + PRAGMA_UNROLL + for (uint32_t i = 0; i < 3; i++) { + gpuAtomicAdd(v_means + i, v_mean[i]); + } + } + // Directly output gradients w.r.t. 
the quaternion and scale + warpSum(v_quat, warp_group_g); + warpSum(v_scale, warp_group_g); + if (warp_group_g.thread_rank() == 0) { + v_quats += gid * 4; + v_scales += gid * 3; + gpuAtomicAdd(v_quats, v_quat[0]); + gpuAtomicAdd(v_quats + 1, v_quat[1]); + gpuAtomicAdd(v_quats + 2, v_quat[2]); + gpuAtomicAdd(v_quats + 3, v_quat[3]); + gpuAtomicAdd(v_scales, v_scale[0]); + gpuAtomicAdd(v_scales + 1, v_scale[1]); + gpuAtomicAdd(v_scales + 2, v_scale[2]); + } + // not supported yet for viewmats + // if (v_viewmats != nullptr) { + // auto warp_group_c = cg::labeled_partition(warp, cid); + // warpSum(v_R, warp_group_c); + // warpSum(v_t, warp_group_c); + // if (warp_group_c.thread_rank() == 0) { + // v_viewmats += cid * 16; + // PRAGMA_UNROLL + // for (uint32_t i = 0; i < 3; i++) { // rows + // PRAGMA_UNROLL + // for (uint32_t j = 0; j < 3; j++) { // cols + // gpuAtomicAdd(v_viewmats + i * 4 + j, v_R[j][i]); + // } + // gpuAtomicAdd(v_viewmats + i * 4 + 3, v_t[i]); + // } + // } + // } +} + +std::tuple +view_to_gaussians_bwd_tensor( + // fwd inputs + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &radii, // [C, N] + // fwd outputs + const torch::Tensor &view2gaussians, // [C, N, 10] + // grad outputs + const torch::Tensor &v_view2gaussians, // [C, N, 10] + const bool viewmats_requires_grad) { + DEVICE_GUARD(means); + CHECK_INPUT(means); + CHECK_INPUT(quats); + CHECK_INPUT(scales); + CHECK_INPUT(viewmats); + CHECK_INPUT(radii); + CHECK_INPUT(view2gaussians); + CHECK_INPUT(v_view2gaussians); + + uint32_t N = means.size(0); // number of gaussians + uint32_t C = viewmats.size(0); // number of cameras + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + torch::Tensor v_means = torch::zeros_like(means); + torch::Tensor v_quats = torch::zeros_like(quats); + torch::Tensor v_scales = torch::zeros_like(scales); + torch::Tensor v_viewmats; + if (viewmats_requires_grad) { + v_viewmats = torch::zeros_like(viewmats); + } + if (C && N) { + view_to_gaussians_bwd_kernel<<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>( + C, N, + means.data_ptr(), + quats.data_ptr(), + scales.data_ptr(), + viewmats.data_ptr(), + radii.data_ptr(), + view2gaussians.data_ptr(), + v_view2gaussians.data_ptr(), + v_means.data_ptr(), + v_quats.data_ptr(), + v_scales.data_ptr(), + viewmats_requires_grad ? v_viewmats.data_ptr() : nullptr); + // view_to_gaussians_bwd_kernel<<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>( + // C, N, + // means.data_ptr(), + // quats.data_ptr(), + // scales.data_ptr(), + // viewmats.data_ptr(), + // radii.data_ptr(), + // view2gaussians.data_ptr(), + // v_view2gaussians.data_ptr(), + // v_means.data_ptr(), + // v_quats.data_ptr(), + // v_scales.data_ptr(), + // viewmats_requires_grad ? 
v_viewmats.data_ptr() : nullptr); + } + return std::make_tuple(v_means, v_quats, v_scales, v_viewmats); +} diff --git a/gsplat/cuda/csrc/view_to_gaussians_fwd.cu b/gsplat/cuda/csrc/view_to_gaussians_fwd.cu new file mode 100644 index 000000000..9b5331010 --- /dev/null +++ b/gsplat/cuda/csrc/view_to_gaussians_fwd.cu @@ -0,0 +1,140 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "utils.cuh" + +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + + +/**************************************************************************** + * Projection of Gaussians (Single Batch) Forward Pass + ****************************************************************************/ + +template +__global__ void +view_to_gaussians_fwd_kernel(const uint32_t C, const uint32_t N, + const T *__restrict__ means, // [N, 3] + const T *__restrict__ quats, // [N, 4] + const T *__restrict__ scales, // [N, 3] + const T *__restrict__ camtoworlds, // [C, 4, 4] + const int32_t *__restrict__ radii, // [C, N] + // outputs + T *__restrict__ view2gaussians // [C, N, 10] +) { + // parallelize over C * N. + uint32_t idx = cg::this_grid().thread_rank(); + if (idx >= C * N || radii[idx] <= 0) { + return; + } + const uint32_t cid = idx / N; // camera id + const uint32_t gid = idx % N; // gaussian id + + // shift pointers to the current camera and gaussian + means += gid * 3; + camtoworlds += cid * 16; + quats += gid * 4; + scales += gid * 3; + view2gaussians += idx * 10; + + // glm is column-major but input is row-major + mat3 camtoworlds_R = mat3( + camtoworlds[0], camtoworlds[4], camtoworlds[8], // 1st column + camtoworlds[1], camtoworlds[5], camtoworlds[9], // 2nd column + camtoworlds[2], camtoworlds[6], camtoworlds[10] // 3rd column + ); + vec3 camtoworlds_t = vec3(camtoworlds[3], camtoworlds[7], camtoworlds[11]); + + vec3 mean = glm::make_vec3(means); + vec4 quat = glm::make_vec4(quats); + vec3 scale = glm::make_vec3(scales); + mat3 rotmat = quat_to_rotmat(quat); + + mat3 worldtogaussian_R = glm::transpose(rotmat); + vec3 worldtogaussian_t = -worldtogaussian_R * mean; + + mat3 view2gaussian_R = worldtogaussian_R * camtoworlds_R; + vec3 view2gaussian_t = worldtogaussian_R * camtoworlds_t + worldtogaussian_t; + + // precompute the value here to avoid repeated computations also reduce IO + // v is the viewdirection and v^T is the transpose of v + // t = position of the camera in the gaussian coordinate system + // A = v^T @ R^T @ S^-1 @ S^-1 @ R @ v + // B = t^T @ S^-1 @ S^-1 @ R @ v + // C = t^T @ S^-1 @ S^-1 @ t + // For the given caemra, t is fix and v depends on the pixel + // therefore we can precompute A, B, C and use them in the forward pass + // For A, we can precompute R^T @ S^-1 @ S^-1 @ R, which is a symmetric matrix and only store the upper triangle in 6 values + // For B, we can precompute S^-1 @ S^-1 @ R @ v, which is a vector and store it in 3 values + // and C is fixed, so we only need to store 1 value + // Therefore, we only need to store 10 values in the view2gaussian matrix + vec3 scales_inv_square = {1.0f / (scale.x * scale.x + 1e-10f), 1.0f / (scale.y * scale.y + 1e-10f), 1.0f / (scale.z * scale.z + 1e-10f)}; + T CC = view2gaussian_t.x * view2gaussian_t.x * scales_inv_square.x + \ + view2gaussian_t.y * view2gaussian_t.y * scales_inv_square.y + \ + view2gaussian_t.z * view2gaussian_t.z * scales_inv_square.z; + + mat3 scales_inv_square_R = mat3( + scales_inv_square.x * view2gaussian_R[0][0], scales_inv_square.y * view2gaussian_R[0][1], scales_inv_square.z * 
view2gaussian_R[0][2], + scales_inv_square.x * view2gaussian_R[1][0], scales_inv_square.y * view2gaussian_R[1][1], scales_inv_square.z * view2gaussian_R[1][2], + scales_inv_square.x * view2gaussian_R[2][0], scales_inv_square.y * view2gaussian_R[2][1], scales_inv_square.z * view2gaussian_R[2][2] + ); + + vec3 BB = view2gaussian_t * scales_inv_square_R; + mat3 Sigma = glm::transpose(view2gaussian_R) * scales_inv_square_R; + + // write to view2gaussian + view2gaussians[0] = Sigma[0][0]; + view2gaussians[1] = Sigma[0][1]; + view2gaussians[2] = Sigma[0][2]; + view2gaussians[3] = Sigma[1][1]; + view2gaussians[4] = Sigma[1][2]; + view2gaussians[5] = Sigma[2][2]; + view2gaussians[6] = BB.x; + view2gaussians[7] = BB.y; + view2gaussians[8] = BB.z; + view2gaussians[9] = CC; +} + + +torch::Tensor view_to_gaussians_fwd_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &radii // [C, N] +) { + DEVICE_GUARD(means); + CHECK_INPUT(means); + CHECK_INPUT(quats); + CHECK_INPUT(scales); + CHECK_INPUT(viewmats); + CHECK_INPUT(radii); + + uint32_t N = means.size(0); // number of gaussians + uint32_t C = viewmats.size(0); // number of cameras + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + torch::Tensor view2gaussians = torch::empty({C, N, 10}, means.options()); + + if (C && N) { + view_to_gaussians_fwd_kernel<<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>( + C, N, means.data_ptr(), + quats.data_ptr(), + scales.data_ptr(), + viewmats.data_ptr(), + radii.data_ptr(), + view2gaussians.data_ptr()); + // view_to_gaussians_fwd_kernel<<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>( + // C, N, means.data_ptr(), + // quats.data_ptr(), + // scales.data_ptr(), + // viewmats.data_ptr(), + // radii.data_ptr(), + // view2gaussians.data_ptr()); + } + return view2gaussians; +} diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 7339e5b2f..dc0c5e554 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -3,14 +3,20 @@ import torch from torch import Tensor +import torch.nn.functional as F from typing_extensions import Literal from .cuda._wrapper import ( fully_fused_projection, isect_offset_encode, isect_tiles, + points_isect_tiles, rasterize_to_pixels, + raytracing_to_pixels, + integrate_to_points, spherical_harmonics, + view_to_gaussians, + project_points, ) @@ -622,3 +628,836 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): render_colors.append(render_colors_) render_colors = torch.stack(render_colors, dim=0) return render_colors, None, {} + + +def _view_to_gaussians( + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] + scales: Tensor, # [N, 3] + camtoworlds: Tensor, # [C, 4, 4] world to camera transform +) -> Tensor: + + quats = F.normalize(quats, p=2, dim=-1) + w, x, y, z = torch.unbind(quats, dim=-1) + R = torch.stack( + [ + 1 - 2 * (y**2 + z**2), + 2 * (x * y - w * z), + 2 * (x * z + w * y), + 2 * (x * y + w * z), + 1 - 2 * (x**2 + z**2), + 2 * (y * z - w * x), + 2 * (x * z - w * y), + 2 * (y * z + w * x), + 1 - 2 * (x**2 + y**2), + ], + dim=-1, + ) + + R = R.reshape(quats.shape[:-1] + (3, 3)) # (..., 3, 3) + # R.register_hook(lambda grad: print("grad R", grad)) + + worldtogaussian = torch.zeros( + (means.shape[0], 4, 4), device=means.device, dtype=means.dtype + ) + worldtogaussian[:, :3, :3] = R.transpose(1, 2) # (..., 3, 3) + worldtogaussian[:, :3, 3:] = -R.transpose(1, 2) @ means[:, :, 
None] # (..., 3) + worldtogaussian[:, 3, 3] = 1.0 + + view2gaussians = ( + worldtogaussian[None, ...] @ camtoworlds[:, None, ...] + ) # [C, N, 4, 4] + R = view2gaussians[..., :3, :3] + t = view2gaussians[..., :3, 3:] + scales_inv_square = 1.0 / (scales**2 + 1e-10) + + C = torch.sum((t**2) * scales_inv_square[None, :, :, None], dim=2) # [C, N, 3, 1] + scales_inv_square_R = scales_inv_square[None, :, :, None] * R + B = t.transpose(-2, -1) @ scales_inv_square_R + Sigma = R.transpose(-2, -1) @ scales_inv_square_R + merged = torch.cat( + [Sigma[:, :, :, 0], Sigma[:, :, 1:, 1], Sigma[:, :, 2:, 2], B.squeeze(-2), C], + dim=-1, + ) + + return merged + + +def raytracing( + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] + scales: Tensor, # [N, 3] + opacities: Tensor, # [N] + colors: Tensor, # [(C,) N, D] or [(C,) N, K, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, + far_plane: float = 1e10, + radius_clip: float = 0.0, + eps2d: float = 0.3, + sh_degree: Optional[int] = None, + packed: bool = True, + tile_size: int = 16, + backgrounds: Optional[Tensor] = None, + render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED"] = "RGB", + sparse_grad: bool = False, + absgrad: bool = False, + rasterize_mode: Literal["classic", "antialiased"] = "classic", + channel_chunk: int = 32, +) -> Tuple[Tensor, Tensor, Dict]: + """Render a set of 3D Gaussians (N) to a batch of image planes (C) using raytracing as in Gaussian Opacity Fields. + + TODO: update the doc to raytracing. + This function provides a handful features for 3D Gaussian rasterization, which + we detail in the following notes. A complete profiling of the these features + can be found in the :ref:`profiling` page. + + .. note:: + **Batch Rasterization**: This function allows for rasterizing a set of 3D Gaussians + to a batch of images in one go, by simplly providing the batched `viewmats` and `Ks`. + + .. note:: + **Support N-D Features**: If `sh_degree` is None, + the `colors` is expected to be with shape [N, D] or [C, N, D], in which D is the channel of + the features to be rendered. The computation is slow when D > 32 at the moment. + If `sh_degree` is set, the `colors` is expected to be the SH coefficients with + shape [N, K, 3] or [C, N, K, 3], where K is the number of SH bases. In this case, it is expected + that :math:`(\\textit{sh_degree} + 1) ^ 2 \\leq K`, where `sh_degree` controls the + activated bases in the SH coefficients. + + .. note:: + **Depth Rendering**: This function supports colors or/and depths via `render_mode`. + The supported modes are "RGB", "D", "ED", "RGB+D", and "RGB+ED". "RGB" renders the + colored image that respects the `colors` argument. "D" renders the accumulated z-depth + :math:`\\sum_i w_i z_i`. "ED" renders the expected z-depth + :math:`\\frac{\\sum_i w_i z_i}{\\sum_i w_i}`. "RGB+D" and "RGB+ED" render both + the colored image and the depth, in which the depth is the last channel of the output. + + .. note:: + **Memory-Speed Trade-off**: The `packed` argument provides a trade-off between + memory footprint and runtime. If `packed` is True, the intermediate results are + packed into sparse tensors, which is more memory efficient but might be slightly + slower. This is especially helpful when the scene is large and each camera sees only + a small portion of the scene. If `packed` is False, the intermediate results are + with shape [C, N, ...], which is faster but might consume more memory. + + .. 
note:: + **Sparse Gradients**: If `sparse_grad` is True, the gradients for {means, quats, scales} + will be stored in a `COO sparse layout `_. + This can be helpful for saving memory + for training when the scene is large and each iteration only activates a small portion + of the Gaussians. Usually a sparse optimizer is required to work with sparse gradients, + such as `torch.optim.SparseAdam `_. + This argument is only effective when `packed` is True. + + .. note:: + **Speed-up for Large Scenes**: The `radius_clip` argument is extremely helpful for + speeding up large scale scenes or scenes with large depth of fields. Gaussians with + 2D radius smaller or equal than this value (in pixel unit) will be skipped during rasterization. + This will skip all the far-away Gaussians that are too small to be seen in the image. + But be warned that if there are close-up Gaussians that are also below this threshold, they will + also get skipped (which is rarely happened in practice). This is by default disabled by setting + `radius_clip` to 0.0. + + .. note:: + **Antialiased Rendering**: If `rasterize_mode` is "antialiased", the function will + apply a view-dependent compensation factor + :math:`\\rho=\\sqrt{\\frac{Det(\\Sigma)}{Det(\\Sigma+ \\epsilon I)}}` to Gaussian + opacities, where :math:`\\Sigma` is the projected 2D covariance matrix and :math:`\\epsilon` + is the `eps2d`. This will make the rendered image more antialiased, as proposed in + the paper `Mip-Splatting: Alias-free 3D Gaussian Splatting `_. + + .. note:: + **AbsGrad**: If `absgrad` is True, the absolute gradients of the projected + 2D means will be computed during the backward pass, which could be accessed by + `meta["means2d"].absgrad`. This is an implementation of the paper + `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_, + which is shown to be more effective for splitting Gaussians during training. + + .. warning:: + This function is currently not differentiable w.r.t. the camera intrinsics `Ks`. + + Args: + means: The 3D centers of the Gaussians. [N, 3] + quats: The quaternions of the Gaussians. It's not required to be normalized. [N, 4] + scales: The scales of the Gaussians. [N, 3] + opacities: The opacities of the Gaussians. [N] + colors: The colors of the Gaussians. [(C,) N, D] or [(C,) N, K, 3] for SH coefficients. + viewmats: The world-to-cam transformation of the cameras. [C, 4, 4] + Ks: The camera intrinsics. [C, 3, 3] + width: The width of the image. + height: The height of the image. + near_plane: The near plane for clipping. Default is 0.01. + far_plane: The far plane for clipping. Default is 1e10. + radius_clip: Gaussians with 2D radius smaller or equal than this value will be + skipped. This is extremely helpful for speeding up large scale scenes. + Default is 0.0. + eps2d: An epsilon added to the egienvalues of projected 2D covariance matrices. + This will prevents the projected GS to be too small. For example eps2d=0.3 + leads to minimal 3 pixel unit. Default is 0.3. + sh_degree: The SH degree to use, which can be smaller than the total + number of bands. If set, the `colors` should be [(C,) N, K, 3] SH coefficients, + else the `colors` should [(C,) N, D] post-activation color values. Default is None. + packed: Whether to use packed mode which is more memory efficient but might or + might not be as fast. Default is True. + tile_size: The size of the tiles for rasterization. Default is 16. + (Note: other values are not tested) + backgrounds: The background colors. [C, D]. Default is None. 
+ render_mode: The rendering mode. Supported modes are "RGB", "D", "ED", "RGB+D", + and "RGB+ED". "RGB" renders the colored image, "D" renders the accumulated depth, and + "ED" renders the expected depth. Default is "RGB". + sparse_grad: If true, the gradients for {means, quats, scales} will be stored in + a COO sparse layout. This can be helpful for saving memory. Default is False. + absgrad: If true, the absolute gradients of the projected 2D means + will be computed during the backward pass, which could be accessed by + `meta["means2d"].absgrad`. Default is False. + rasterize_mode: The rasterization mode. Supported modes are "classic" and + "antialiased". Default is "classic". + channel_chunk: The number of channels to render in one go. Default is 32. + If the required rendering channels are larger than this value, the rendering + will be done looply in chunks. + + Returns: + A tuple: + + **render_colors**: The rendered colors. [C, height, width, X]. + X depends on the `render_mode` and input `colors`. If `render_mode` is "RGB", + X is D; if `render_mode` is "D" or "ED", X is 1; if `render_mode` is "RGB+D" or + "RGB+ED", X is D+1. + + **render_alphas**: The rendered alphas. [C, height, width, 1]. + + **meta**: A dictionary of intermediate results of the rasterization. + + Examples: + + .. code-block:: python + + >>> # define Gaussians + >>> means = torch.randn((100, 3), device=device) + >>> quats = torch.randn((100, 4), device=device) + >>> scales = torch.rand((100, 3), device=device) * 0.1 + >>> colors = torch.rand((100, 3), device=device) + >>> opacities = torch.rand((100,), device=device) + >>> # define cameras + >>> viewmats = torch.eye(4, device=device)[None, :, :] + >>> Ks = torch.tensor([ + >>> [300., 0., 150.], [0., 300., 100.], [0., 0., 1.]], device=device)[None, :, :] + >>> width, height = 300, 200 + >>> # render + >>> colors, alphas, meta = rasterization( + >>> means, quats, scales, opacities, colors, viewmats, Ks, width, height + >>> ) + >>> print (colors.shape, alphas.shape) + torch.Size([1, 200, 300, 3]) torch.Size([1, 200, 300, 1]) + >>> print (meta.keys()) + dict_keys(['camera_ids', 'gaussian_ids', 'radii', 'means2d', 'depths', 'conics', + 'opacities', 'tile_width', 'tile_height', 'tiles_per_gauss', 'isect_ids', + 'flatten_ids', 'isect_offsets', 'width', 'height', 'tile_size']) + + """ + + N = means.shape[0] + C = viewmats.shape[0] + assert means.shape == (N, 3), means.shape + assert quats.shape == (N, 4), quats.shape + assert scales.shape == (N, 3), scales.shape + assert opacities.shape == (N,), opacities.shape + assert viewmats.shape == (C, 4, 4), viewmats.shape + assert Ks.shape == (C, 3, 3), Ks.shape + assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode + assert packed == False, "Packed mode is not supported yet." 
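Before the projection and rasterization calls below, it may help to see how the 10 `view2gaussians` values produced by `view_to_gaussians` are consumed per pixel inside the forward raytracing kernel. The following pure-PyTorch sketch mirrors that per-ray math for a single Gaussian; the function and variable names are illustrative only and not part of the library:

    import torch

    def _eval_gaussian_along_ray(v2g, ray, opac):
        # v2g: [10] = upper triangle of R^T S^-2 R (6 values), then B (3), then C (1)
        # ray: [3] camera-space direction ((px - cx) / fx, (py - cy) / fy, 1.0)
        normal = torch.stack([
            v2g[0] * ray[0] + v2g[1] * ray[1] + v2g[2],
            v2g[1] * ray[0] + v2g[3] * ray[1] + v2g[4],
            v2g[2] * ray[0] + v2g[4] * ray[1] + v2g[5],
        ])
        A = ray[0] * normal[0] + ray[1] * normal[1] + normal[2]
        B = 2 * (v2g[6] * ray[0] + v2g[7] * ray[1] + v2g[8])
        C = v2g[9]
        depth = -B / (2 * A)                          # depth of the peak response along the ray
        power = torch.clamp_max(-0.5 * (C - B * B / (4 * A)), 0.0)
        alpha = torch.clamp_max(opac * torch.exp(power), 0.999)
        return depth, alpha, -normal / normal.norm()  # kernel flips and normalizes the normal

Precomputing these ten scalars per Gaussian keeps the per-pixel work down to a few fused multiply-adds instead of applying a full 4x4 view-to-Gaussian transform for every ray.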
+ + if sh_degree is None: + # treat colors as post-activation values, should be in shape [N, D] or [C, N, D] + assert (colors.dim() == 2 and colors.shape[0] == N) or ( + colors.dim() == 3 and colors.shape[:2] == (C, N) + ), colors.shape + else: + # treat colors as SH coefficients, should be in shape [N, K, 3] or [C, N, K, 3] + # Allowing for activating partial SH bands + assert ( + colors.dim() == 3 and colors.shape[0] == N and colors.shape[2] == 3 + ) or ( + colors.dim() == 4 and colors.shape[:2] == (C, N) and colors.shape[3] == 3 + ), colors.shape + assert (sh_degree + 1) ** 2 <= colors.shape[-2], colors.shape + + # we don't need to gradient for the projection, it only used for creating tile-based rendering + with torch.no_grad(): + # Project Gaussians to 2D. Directly pass in {quats, scales} is faster than precomputing covars. + proj_results = fully_fused_projection( + means, + None, # covars, + quats, + scales, + viewmats, + Ks, + width, + height, + eps2d=eps2d, + packed=packed, + near_plane=near_plane, + far_plane=far_plane, + radius_clip=radius_clip, + sparse_grad=sparse_grad, + calc_compensations=(rasterize_mode == "antialiased"), + ) + + if packed: + # The results are packed into shape [nnz, ...]. All elements are valid. + ( + camera_ids, + gaussian_ids, + radii, + means2d, + depths, + conics, + compensations, + ) = proj_results + opacities = opacities[gaussian_ids] # [nnz] + else: + # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. + radii, means2d, depths, conics, compensations = proj_results + opacities = opacities.repeat(C, 1) # [C, N] + camera_ids, gaussian_ids = None, None + + # if compensations is not None: + # opacities = opacities * compensations + + # Identify intersecting tiles + tile_width = math.ceil(width / float(tile_size)) + tile_height = math.ceil(height / float(tile_size)) + tiles_per_gauss, isect_ids, flatten_ids = isect_tiles( + means2d, + radii, + depths, + tile_size, + tile_width, + tile_height, + packed=packed, + n_cameras=C, + camera_ids=camera_ids, + gaussian_ids=gaussian_ids, + ) + isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height) + + # Turn colors into [C, N, D] or [nnz, D] to pass into rasterize_to_pixels() + if sh_degree is None: + # Colors are post-activation values, with shape [N, D] or [C, N, D] + if packed: + if colors.dim() == 2: + # Turn [N, D] into [nnz, D] + colors = colors[gaussian_ids] + else: + # Turn [C, N, D] into [nnz, D] + colors = colors[camera_ids, gaussian_ids] + else: + if colors.dim() == 2: + # Turn [N, D] into [C, N, D] + colors = colors.expand(C, -1, -1) + else: + # colors is already [C, N, D] + pass + else: + # Colors are SH coefficients, with shape [N, K, 3] or [C, N, K, 3] + camtoworlds = torch.inverse(viewmats) # [C, 4, 4] + if packed: + dirs = means[gaussian_ids, :] - camtoworlds[camera_ids, :3, 3] # [nnz, 3] + masks = radii > 0 # [nnz] + if colors.dim() == 3: + # Turn [N, K, 3] into [nnz, 3] + shs = colors[gaussian_ids, :, :] # [nnz, K, 3] + else: + # Turn [C, N, K, 3] into [nnz, 3] + shs = colors[camera_ids, gaussian_ids, :, :] # [nnz, K, 3] + colors = spherical_harmonics(sh_degree, dirs, shs, masks=masks) # [nnz, 3] + else: + dirs = means[None, :, :] - camtoworlds[:, None, :3, 3] # [C, N, 3] + masks = radii > 0 # [C, N] + if colors.dim() == 3: + # Turn [N, K, 3] into [C, N, 3] + shs = colors.expand(C, -1, -1, -1) # [C, N, K, 3] + else: + # colors is already [C, N, K, 3] + shs = colors + colors = spherical_harmonics(sh_degree, dirs, shs, masks=masks) # [C, N, 3] + # 
make it apple-to-apple with Inria's CUDA Backend. + colors = torch.clamp_min(colors + 0.5, 0.0) + + # precompute view to gaussian + camtoworlds = torch.linalg.inv(viewmats) # [C, 4, 4] + view2gaussians = view_to_gaussians(means, quats, scales, camtoworlds, radii) + + # hack to retain_grad for means2d as we need it for gaussian densification + if means.requires_grad: + means2d.requires_grad = True + + # Rasterize to pixels + # if render_mode in ["RGB+D", "RGB+ED"]: + # colors = torch.cat((colors, depths[..., None]), dim=-1) + # if backgrounds is not None: + # backgrounds = torch.cat( + # [backgrounds, torch.zeros(C, 1, device=backgrounds.device)], dim=-1 + # ) + # elif render_mode in ["D", "ED"]: + # colors = depths[..., None] + # if backgrounds is not None: + # backgrounds = torch.zeros(C, 1, device=backgrounds.device) + # else: # RGB + # pass + + if colors.shape[-1] > channel_chunk: + # slice into chunks + n_chunks = (colors.shape[-1] + channel_chunk - 1) // channel_chunk + render_colors, render_alphas = [], [] + for i in range(n_chunks): + colors_chunk = colors[..., i * channel_chunk : (i + 1) * channel_chunk] + backgrounds_chunk = ( + backgrounds[..., i * channel_chunk : (i + 1) * channel_chunk] + if backgrounds is not None + else None + ) + render_colors_, render_alphas_ = raytracing_to_pixels( + means2d, + conics, + colors_chunk, + opacities, + view2gaussians, + Ks, + width, + height, + tile_size, + isect_offsets, + flatten_ids, + backgrounds=backgrounds_chunk, + packed=packed, + absgrad=absgrad, + ) + render_colors.append(render_colors_) + render_alphas.append(render_alphas_) + render_colors = torch.cat(render_colors, dim=-1) + render_alphas = render_alphas[0] # discard the rest + else: + render_colors, render_alphas = raytracing_to_pixels( + means2d, + conics, + colors, + opacities, + view2gaussians, + Ks, + width, + height, + tile_size, + isect_offsets, + flatten_ids, + backgrounds=backgrounds, + packed=packed, + absgrad=absgrad, + ) + + # if render_mode in ["ED", "RGB+ED"]: + # # normalize the accumulated depth to get the expected depth + # render_colors = torch.cat( + # [ + # render_colors[..., :-1], + # render_colors[..., -1:] / render_alphas.clamp(min=1e-10), + # ], + # dim=-1, + # ) + + meta = { + "camera_ids": camera_ids, + "gaussian_ids": gaussian_ids, + "radii": radii, + "means2d": means2d, + "depths": depths, + "conics": conics, + "opacities": opacities, + "tile_width": tile_width, + "tile_height": tile_height, + "tiles_per_gauss": tiles_per_gauss, + "isect_ids": isect_ids, + "flatten_ids": flatten_ids, + "isect_offsets": isect_offsets, + "width": width, + "height": height, + "tile_size": tile_size, + } + return render_colors, render_alphas, meta + +@torch.no_grad() +def integration( + points: Tensor, # [N, 3] + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] + scales: Tensor, # [N, 3] + opacities: Tensor, # [N] + colors: Tensor, # [(C,) N, D] or [(C,) N, K, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, + far_plane: float = 1e10, + radius_clip: float = 0.0, + eps2d: float = 0.3, + sh_degree: Optional[int] = None, + packed: bool = False, + tile_size: int = 16, + backgrounds: Optional[Tensor] = None, + render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED"] = "RGB", + sparse_grad: bool = False, + absgrad: bool = False, + rasterize_mode: Literal["classic", "antialiased"] = "classic", +) -> Tuple[Tensor, Tensor, Dict]: + """evaluating Gaussian Opacity Fields from a set of 3D Gaussians (N) to a batch of 
image planes (C) using. + + TODO: update the doc to integration. + This function provides a handful features for 3D Gaussian rasterization, which + we detail in the following notes. A complete profiling of the these features + can be found in the :ref:`profiling` page. + + .. note:: + **Batch Rasterization**: This function allows for rasterizing a set of 3D Gaussians + to a batch of images in one go, by simplly providing the batched `viewmats` and `Ks`. + + .. note:: + **Support N-D Features**: If `sh_degree` is None, + the `colors` is expected to be with shape [N, D] or [C, N, D], in which D is the channel of + the features to be rendered. The computation is slow when D > 32 at the moment. + If `sh_degree` is set, the `colors` is expected to be the SH coefficients with + shape [N, K, 3] or [C, N, K, 3], where K is the number of SH bases. In this case, it is expected + that :math:`(\\textit{sh_degree} + 1) ^ 2 \\leq K`, where `sh_degree` controls the + activated bases in the SH coefficients. + + .. note:: + **Depth Rendering**: This function supports colors or/and depths via `render_mode`. + The supported modes are "RGB", "D", "ED", "RGB+D", and "RGB+ED". "RGB" renders the + colored image that respects the `colors` argument. "D" renders the accumulated z-depth + :math:`\\sum_i w_i z_i`. "ED" renders the expected z-depth + :math:`\\frac{\\sum_i w_i z_i}{\\sum_i w_i}`. "RGB+D" and "RGB+ED" render both + the colored image and the depth, in which the depth is the last channel of the output. + + .. note:: + **Memory-Speed Trade-off**: The `packed` argument provides a trade-off between + memory footprint and runtime. If `packed` is True, the intermediate results are + packed into sparse tensors, which is more memory efficient but might be slightly + slower. This is especially helpful when the scene is large and each camera sees only + a small portion of the scene. If `packed` is False, the intermediate results are + with shape [C, N, ...], which is faster but might consume more memory. + + .. note:: + **Sparse Gradients**: If `sparse_grad` is True, the gradients for {means, quats, scales} + will be stored in a `COO sparse layout `_. + This can be helpful for saving memory + for training when the scene is large and each iteration only activates a small portion + of the Gaussians. Usually a sparse optimizer is required to work with sparse gradients, + such as `torch.optim.SparseAdam `_. + This argument is only effective when `packed` is True. + + .. note:: + **Speed-up for Large Scenes**: The `radius_clip` argument is extremely helpful for + speeding up large scale scenes or scenes with large depth of fields. Gaussians with + 2D radius smaller or equal than this value (in pixel unit) will be skipped during rasterization. + This will skip all the far-away Gaussians that are too small to be seen in the image. + But be warned that if there are close-up Gaussians that are also below this threshold, they will + also get skipped (which is rarely happened in practice). This is by default disabled by setting + `radius_clip` to 0.0. + + .. note:: + **Antialiased Rendering**: If `rasterize_mode` is "antialiased", the function will + apply a view-dependent compensation factor + :math:`\\rho=\\sqrt{\\frac{Det(\\Sigma)}{Det(\\Sigma+ \\epsilon I)}}` to Gaussian + opacities, where :math:`\\Sigma` is the projected 2D covariance matrix and :math:`\\epsilon` + is the `eps2d`. This will make the rendered image more antialiased, as proposed in + the paper `Mip-Splatting: Alias-free 3D Gaussian Splatting `_. + + .. 
note:: + **AbsGrad**: If `absgrad` is True, the absolute gradients of the projected + 2D means will be computed during the backward pass, which could be accessed by + `meta["means2d"].absgrad`. This is an implementation of the paper + `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_, + which is shown to be more effective for splitting Gaussians during training. + + .. warning:: + This function is currently not differentiable w.r.t. the camera intrinsics `Ks`. + + Args: + means: The 3D centers of the Gaussians. [N, 3] + quats: The quaternions of the Gaussians. They are not required to be normalized. [N, 4] + scales: The scales of the Gaussians. [N, 3] + opacities: The opacities of the Gaussians. [N] + colors: The colors of the Gaussians. [(C,) N, D] or [(C,) N, K, 3] for SH coefficients. + viewmats: The world-to-cam transformation of the cameras. [C, 4, 4] + Ks: The camera intrinsics. [C, 3, 3] + width: The width of the image. + height: The height of the image. + near_plane: The near plane for clipping. Default is 0.01. + far_plane: The far plane for clipping. Default is 1e10. + radius_clip: Gaussians with a 2D radius smaller than or equal to this value will be + skipped. This is extremely helpful for speeding up large-scale scenes. + Default is 0.0. + eps2d: An epsilon added to the eigenvalues of projected 2D covariance matrices. + This prevents the projected Gaussians from becoming too small. For example, eps2d=0.3 + leads to a minimum size of about 3 pixels. Default is 0.3. + sh_degree: The SH degree to use, which can be smaller than the total + number of bands. If set, `colors` should be [(C,) N, K, 3] SH coefficients, + otherwise `colors` should be [(C,) N, D] post-activation color values. Default is None. + packed: Whether to use packed mode, which is more memory efficient but might or + might not be as fast. Default is False (packed mode is not supported by this function yet). + tile_size: The size of the tiles for rasterization. Default is 16. + (Note: other values are not tested) + backgrounds: The background colors. [C, D]. Default is None. + render_mode: The rendering mode. Supported modes are "RGB", "D", "ED", "RGB+D", + and "RGB+ED". "RGB" renders the colored image, "D" renders the accumulated depth, and + "ED" renders the expected depth. Default is "RGB". + sparse_grad: If true, the gradients for {means, quats, scales} will be stored in + a COO sparse layout. This can be helpful for saving memory. Default is False. + absgrad: If true, the absolute gradients of the projected 2D means + will be computed during the backward pass, which could be accessed by + `meta["means2d"].absgrad`. Default is False. + rasterize_mode: The rasterization mode. Supported modes are "classic" and + "antialiased". Default is "classic". + channel_chunk: The number of channels to render in one go. Default is 32. + If the required rendering channels exceed this value, the rendering + will be done in chunks over the channels. + + Returns: + A tuple: + + **render_colors**: The rendered colors. [C, height, width, X]. + X depends on the `render_mode` and input `colors`. If `render_mode` is "RGB", + X is D; if `render_mode` is "D" or "ED", X is 1; if `render_mode` is "RGB+D" or + "RGB+ED", X is D+1. + + **render_alphas**: The rendered alphas. [C, height, width, 1]. + + **meta**: A dictionary of intermediate results of the rasterization. + + Examples: + + .. 
code-block:: python + + >>> # define Gaussians + >>> means = torch.randn((100, 3), device=device) + >>> quats = torch.randn((100, 4), device=device) + >>> scales = torch.rand((100, 3), device=device) * 0.1 + >>> colors = torch.rand((100, 3), device=device) + >>> opacities = torch.rand((100,), device=device) + >>> # define cameras + >>> viewmats = torch.eye(4, device=device)[None, :, :] + >>> Ks = torch.tensor([ + >>> [300., 0., 150.], [0., 300., 100.], [0., 0., 1.]], device=device)[None, :, :] + >>> width, height = 300, 200 + >>> # render + >>> colors, alphas, meta = rasterization( + >>> means, quats, scales, opacities, colors, viewmats, Ks, width, height + >>> ) + >>> print (colors.shape, alphas.shape) + torch.Size([1, 200, 300, 3]) torch.Size([1, 200, 300, 1]) + >>> print (meta.keys()) + dict_keys(['camera_ids', 'gaussian_ids', 'radii', 'means2d', 'depths', 'conics', + 'opacities', 'tile_width', 'tile_height', 'tiles_per_gauss', 'isect_ids', + 'flatten_ids', 'isect_offsets', 'width', 'height', 'tile_size']) + + """ + + N = means.shape[0] + C = viewmats.shape[0] + assert means.shape == (N, 3), means.shape + assert quats.shape == (N, 4), quats.shape + assert scales.shape == (N, 3), scales.shape + assert opacities.shape == (N,), opacities.shape + assert viewmats.shape == (C, 4, 4), viewmats.shape + assert Ks.shape == (C, 3, 3), Ks.shape + assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode + assert packed == False, "Packed mode is not supported yet." + + if sh_degree is None: + # treat colors as post-activation values, should be in shape [N, D] or [C, N, D] + assert (colors.dim() == 2 and colors.shape[0] == N) or ( + colors.dim() == 3 and colors.shape[:2] == (C, N) + ), colors.shape + else: + # treat colors as SH coefficients, should be in shape [N, K, 3] or [C, N, K, 3] + # Allowing for activating partial SH bands + assert ( + colors.dim() == 3 and colors.shape[0] == N and colors.shape[2] == 3 + ) or ( + colors.dim() == 4 and colors.shape[:2] == (C, N) and colors.shape[3] == 3 + ), colors.shape + assert (sh_degree + 1) ** 2 <= colors.shape[-2], colors.shape + + # Project Gaussians to 2D. Directly pass in {quats, scales} is faster than precomputing covars. + proj_results = fully_fused_projection( + means, + None, # covars, + quats, + scales, + viewmats, + Ks, + width, + height, + eps2d=eps2d, + packed=packed, + near_plane=near_plane, + far_plane=far_plane, + radius_clip=radius_clip, + sparse_grad=sparse_grad, + calc_compensations=(rasterize_mode == "antialiased"), + ) + + if packed: + # The results are packed into shape [nnz, ...]. All elements are valid. + ( + camera_ids, + gaussian_ids, + radii, + means2d, + depths, + conics, + compensations, + ) = proj_results + opacities = opacities[gaussian_ids] # [nnz] + else: + # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. 
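+        # (Shapes in the unpacked path, as assumed from the convention above: radii [C, N],
+        # means2d [C, N, 2], depths [C, N], conics [C, N, 3], and compensations [C, N],
+        # or None when `calc_compensations` is False.)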
+ radii, means2d, depths, conics, compensations = proj_results + opacities = opacities.repeat(C, 1) # [C, N] + camera_ids, gaussian_ids = None, None + + # if compensations is not None: + # opacities = opacities * compensations + + # Identify intersecting tiles + tile_width = math.ceil(width / float(tile_size)) + tile_height = math.ceil(height / float(tile_size)) + tiles_per_gauss, isect_ids, flatten_ids = isect_tiles( + means2d, + radii, + depths, + tile_size, + tile_width, + tile_height, + packed=packed, + n_cameras=C, + camera_ids=camera_ids, + gaussian_ids=gaussian_ids, + ) + isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height) + + # create tiles and isect_offsets for input points + points_radii, points2d, point_depths = project_points( + points, + viewmats, + Ks, + width, + height, + near_plane=near_plane, + far_plane=far_plane, + ) + + point_isect_ids, point_flatten_ids = points_isect_tiles( + points2d, + points_radii, + point_depths, + tile_size, + tile_width, + tile_height, + packed=packed, + n_cameras=C, + camera_ids=camera_ids, + gaussian_ids=gaussian_ids, + ) + point_isect_offsets = isect_offset_encode(point_isect_ids, C, tile_width, tile_height) + + # Turn colors into [C, N, D] or [nnz, D] to pass into rasterize_to_pixels() + if sh_degree is None: + # Colors are post-activation values, with shape [N, D] or [C, N, D] + if packed: + if colors.dim() == 2: + # Turn [N, D] into [nnz, D] + colors = colors[gaussian_ids] + else: + # Turn [C, N, D] into [nnz, D] + colors = colors[camera_ids, gaussian_ids] + else: + if colors.dim() == 2: + # Turn [N, D] into [C, N, D] + colors = colors.expand(C, -1, -1) + else: + # colors is already [C, N, D] + pass + else: + # Colors are SH coefficients, with shape [N, K, 3] or [C, N, K, 3] + camtoworlds = torch.inverse(viewmats) # [C, 4, 4] + if packed: + dirs = means[gaussian_ids, :] - camtoworlds[camera_ids, :3, 3] # [nnz, 3] + masks = radii > 0 # [nnz] + if colors.dim() == 3: + # Turn [N, K, 3] into [nnz, 3] + shs = colors[gaussian_ids, :, :] # [nnz, K, 3] + else: + # Turn [C, N, K, 3] into [nnz, 3] + shs = colors[camera_ids, gaussian_ids, :, :] # [nnz, K, 3] + colors = spherical_harmonics(sh_degree, dirs, shs, masks=masks) # [nnz, 3] + else: + dirs = means[None, :, :] - camtoworlds[:, None, :3, 3] # [C, N, 3] + masks = radii > 0 # [C, N] + if colors.dim() == 3: + # Turn [N, K, 3] into [C, N, 3] + shs = colors.expand(C, -1, -1, -1) # [C, N, K, 3] + else: + # colors is already [C, N, K, 3] + shs = colors + colors = spherical_harmonics(sh_degree, dirs, shs, masks=masks) # [C, N, 3] + # make it apple-to-apple with Inria's CUDA Backend. 
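+        # (The reference Inria CUDA backend adds a 0.5 offset to the SH colors and clamps
+        # negative values to zero after SH evaluation; the line below mirrors that behavior.)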
+ colors = torch.clamp_min(colors + 0.5, 0.0) + + # precompute view to gaussian + camtoworlds = torch.linalg.inv(viewmats) # [C, 4, 4] + view2gaussians = view_to_gaussians(means, quats, scales, camtoworlds, radii) + + integrated_colors, integrated_alphas = integrate_to_points( + points2d, + point_depths, + means2d, + conics, + colors, + opacities, + view2gaussians, + Ks, + width, + height, + tile_size, + isect_offsets, + flatten_ids, + point_isect_offsets, + point_flatten_ids, + backgrounds=backgrounds, + packed=packed, + absgrad=absgrad, + ) + # breakpoint() + + meta = { + "camera_ids": camera_ids, + "gaussian_ids": gaussian_ids, + "radii": radii, + "means2d": means2d, + "depths": depths, + "conics": conics, + "opacities": opacities, + "tile_width": tile_width, + "tile_height": tile_height, + "tiles_per_gauss": tiles_per_gauss, + "isect_ids": isect_ids, + "flatten_ids": flatten_ids, + "isect_offsets": isect_offsets, + "width": width, + "height": height, + "tile_size": tile_size, + } + return integrated_colors, integrated_alphas, meta diff --git a/gsplat/tetmesh.py b/gsplat/tetmesh.py new file mode 100644 index 000000000..e714ce706 --- /dev/null +++ b/gsplat/tetmesh.py @@ -0,0 +1,190 @@ +### copy from https://raw.githubusercontent.com/autonomousvision/gaussian-opacity-fields/main/utils/tetmesh.py + +# Copyright (c) 2021,22 NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch + +__all__ = ['marching_tetrahedra'] + +triangle_table = torch.tensor([ + [-1, -1, -1, -1, -1, -1], + [1, 0, 2, -1, -1, -1], + [4, 0, 3, -1, -1, -1], + [1, 4, 2, 1, 3, 4], + [3, 1, 5, -1, -1, -1], + [2, 3, 0, 2, 5, 3], + [1, 4, 0, 1, 5, 4], + [4, 2, 5, -1, -1, -1], + [4, 5, 2, -1, -1, -1], + [4, 1, 0, 4, 5, 1], + [3, 2, 0, 3, 5, 2], + [1, 3, 5, -1, -1, -1], + [4, 1, 2, 4, 3, 1], + [3, 0, 4, -1, -1, -1], + [2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1] +], dtype=torch.long) + +num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long) +base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long) +v_id = torch.pow(2, torch.arange(4, dtype=torch.long)) + + +def _unbatched_marching_tetrahedra(vertices, tets, sdf, scales): + """unbatched marching tetrahedra. + + Refer to :func:`marching_tetrahedra`. 
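+
+    To bound peak GPU memory, very large tet lists are processed in chunks of
+    32 * 1024 * 1024 tetrahedra and the per-chunk outputs are merged on shared edges.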
+ """ + device = vertices.device + + # call by chunk + chunk_size = 32 * 1024 * 1024 + if tets.shape[0] > chunk_size: + merged_verts = None + merged_scales = None + merged_faces = None + merged_verts_ids = None + for tet_chunk in torch.chunk(tets, tets.shape[0] // chunk_size + 1): + torch.cuda.empty_cache() + verts, verts_scales, faces, verts_ids = _unbatched_marching_tetrahedra(vertices, tet_chunk, sdf, scales) + + if merged_verts is None: + merged_verts = verts + merged_scales = verts_scales + merged_faces = faces + merged_verts_ids = verts_ids + else: + all_edges = torch.cat([merged_verts_ids, verts_ids], dim=0) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + # merge vertices + unique_verts_0 = torch.zeros((unique_edges.shape[0], 2, 3), dtype=torch.float, device=device) + unique_verts_1 = torch.zeros((unique_edges.shape[0], 2, 1), dtype=torch.float, device=device) + unique_verts_0[idx_map[:merged_verts[0].shape[0]]] = merged_verts[0] + unique_verts_0[idx_map[merged_verts[0].shape[0]:]] = verts[0] + unique_verts_1[idx_map[:merged_verts[1].shape[0]]] = merged_verts[1] + unique_verts_1[idx_map[merged_verts[1].shape[0]:]] = verts[1] + # merge scales + unique_scales = torch.zeros((unique_edges.shape[0], 2, 1), dtype=torch.float, device=device) + unique_scales[idx_map[:merged_verts[0].shape[0]]] = merged_scales + unique_scales[idx_map[merged_verts[0].shape[0]:]] = verts_scales + + # merge faces + unique_faces_0 = idx_map[merged_faces.reshape(-1)].reshape(-1, 3) + unique_faces_1 = idx_map[faces.reshape(-1) + merged_verts[0].shape[0]].reshape(-1, 3) + + merged_faces = torch.cat([unique_faces_0, unique_faces_1], dim=0) + merged_verts = (unique_verts_0, unique_verts_1) + merged_scales = unique_scales + merged_verts_ids = unique_edges + torch.cuda.empty_cache() + + return merged_verts, merged_scales, merged_faces, merged_verts_ids + + with torch.no_grad(): + occ_n = sdf > 0 + occ_fx4 = occ_n[tets.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + + valid_tets = (occ_sum > 0) & (occ_sum < 4) + + # find all vertices + all_edges = tets[valid_tets][:, base_tet_edges.to(device)].reshape(-1, 2) + + order = (all_edges[:, 0] > all_edges[:, 1]).bool() + all_edges[order] = all_edges[order][:, [1, 0]] + + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=device) + idx_map = mapping[idx_map] + + interp_v = unique_edges[mask_edges] + edges_to_interp = vertices[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf[interp_v.reshape(-1)].reshape(-1, 2, 1) + verts_scales = scales[interp_v.reshape(-1)].reshape(-1, 2, 1) + + verts = (edges_to_interp, edges_to_interp_sdf) + idx_map = idx_map.reshape(-1, 6) + + tetindex = (occ_fx4[valid_tets] * v_id.to(device).unsqueeze(0)).sum(-1) + num_triangles = num_triangles_table.to(device)[tetindex] + triangle_table_device = triangle_table.to(device) + + # Generate triangle indices + faces = torch.cat(( + torch.gather(input=idx_map[num_triangles == 1], dim=1, + index=triangle_table_device[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), + torch.gather(input=idx_map[num_triangles == 2], dim=1, + index=triangle_table_device[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), + ), dim=0) + + return verts, 
verts_scales, faces, interp_v + + +def marching_tetrahedra(vertices, tets, sdf, scales): + r"""Convert discrete signed distance fields encoded on tetrahedral grids to triangle + meshes using the marching tetrahedra algorithm as described in `An efficient method of + triangulating equi-valued surfaces by using tetrahedral cells`_. The output surface is differentiable with respect to + input vertex positions and the SDF values. For more details and example usage in learning, see + `Deep Marching Tetrahedra\: a Hybrid Representation for High-Resolution 3D Shape Synthesis`_ (NeurIPS 2021). + + + Args: + vertices (torch.tensor): batched vertices of tetrahedral meshes, of shape + :math:`(\text{batch_size}, \text{num_vertices}, 3)`. + tets (torch.tensor): unbatched tetrahedral mesh topology, of shape + :math:`(\text{num_tetrahedrons}, 4)`. + sdf (torch.tensor): batched SDFs which specify the SDF value of each vertex, of shape + :math:`(\text{batch_size}, \text{num_vertices})`. + scales (torch.tensor): batched per-vertex scale values that are gathered along the + interpolated edges, of shape :math:`(\text{batch_size}, \text{num_vertices})`. + + Returns: + (list[torch.Tensor], list[torch.LongTensor], (optional) list[torch.LongTensor]): + + - the list of vertices for mesh converted from each tetrahedral grid. + - the list of faces for mesh converted from each tetrahedral grid. + + Example: + >>> vertices = torch.tensor([[[0, 0, 0], + ... [1, 0, 0], + ... [0, 1, 0], + ... [0, 0, 1]]], dtype=torch.float) + >>> tets = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) + >>> sdf = torch.tensor([[-1., -1., 0.5, 0.5]], dtype=torch.float) + >>> verts_list, faces_list, tet_idx_list = marching_tetrahedra(vertices, tets, sdf, True) + >>> verts_list[0] + tensor([[0.0000, 0.6667, 0.0000], + [0.0000, 0.0000, 0.6667], + [0.3333, 0.6667, 0.0000], + [0.3333, 0.0000, 0.6667]]) + >>> faces_list[0] + tensor([[3, 0, 1], + [3, 2, 0]]) + >>> tet_idx_list[0] + tensor([0, 0]) + + .. _An efficient method of triangulating equi-valued surfaces by using tetrahedral cells: + https://search.ieice.org/bin/summary.php?id=e74-d_1_214 + + .. 
_Deep Marching Tetrahedra\: a Hybrid Representation for High-Resolution 3D Shape Synthesis: + https://arxiv.org/abs/2111.04276 + """ + + list_of_outputs = [_unbatched_marching_tetrahedra(vertices[b], tets, sdf[b], scales[b]) for b in range(vertices.shape[0])] + return list(zip(*list_of_outputs)) \ No newline at end of file diff --git a/tests/test_basic.py b/tests/test_basic.py index 8c546b450..8830d431f 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -571,3 +571,56 @@ def test_sh(test_data, sh_degree: int): torch.testing.assert_close(v_coeffs, _v_coeffs, rtol=1e-4, atol=1e-4) if sh_degree > 0: torch.testing.assert_close(v_dirs, _v_dirs, rtol=1e-4, atol=1e-4) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_view_to_gaussians(): + # from gsplat.cuda._torch_impl import _quat_scale_to_covar_preci + from gsplat.cuda._wrapper import view_to_gaussians + from gsplat.rendering import _view_to_gaussians + + torch.set_float32_matmul_precision("highest") + torch.manual_seed(42) + + means = torch.load("../assets/means.pt") + quats = torch.load("../assets/quats.pt") + scales = torch.load("../assets/scales.pt") + radii = torch.load("../assets/radii.pt") + camtoworlds = torch.load("../assets/camtoworlds.pt") + + means.requires_grad = True + quats.requires_grad = True + scales.requires_grad = True + radii.requires_grad = False + camtoworlds.requires_grad = False + + # to double for numerical precision + means = means.double() + quats = quats.double() + scales = scales.double() + # radii = radii.double() + camtoworlds = camtoworlds.double() + + # forward + # plus 10 to compute view2gaussians for all inputs + radii += 10 + view2gaussians = view_to_gaussians(means, quats, scales, camtoworlds, radii) + _view2gaussians = _view_to_gaussians(means, quats, scales, camtoworlds) + torch.testing.assert_close(view2gaussians[:, :10], _view2gaussians[:, :10]) + + # backward + v_view2gaussians = torch.randn_like(view2gaussians) + v_means, v_quats, v_scales = torch.autograd.grad( + outputs=view2gaussians, + inputs=(means, quats, scales), + grad_outputs=v_view2gaussians, + ) + + _v_means, _v_quats, _v_scales = torch.autograd.grad( + outputs=_view2gaussians, + inputs=(means, quats, scales), + grad_outputs=v_view2gaussians, + ) + torch.testing.assert_close(v_means, _v_means, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(v_quats, _v_quats, rtol=1e-1, atol=1e-1) + torch.testing.assert_close(v_scales, _v_scales, rtol=1e-1, atol=1e-1)
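# A minimal, untested sketch of exercising the scales-aware marching_tetrahedra added
# above. The random tet grid and the 0.5-opacity level set are illustrative
# assumptions (in the spirit of Gaussian Opacity Fields), not part of this diff.
import torch

from gsplat.tetmesh import marching_tetrahedra

device = "cuda" if torch.cuda.is_available() else "cpu"

# A toy tetrahedral grid: vertex positions, tet connectivity, a scalar field per
# vertex (here opacity shifted so that the 0.5 level set becomes the zero level set),
# and a per-vertex scale value.
vertices = torch.rand((1, 1000, 3), device=device)
tets = torch.randint(0, 1000, (5000, 4), device=device)
sdf = torch.rand((1, 1000), device=device) - 0.5
scales = torch.rand((1, 1000), device=device)

verts_list, scales_list, faces_list, interp_v_list = marching_tetrahedra(
    vertices, tets, sdf, scales
)
# verts_list[0] holds, per extracted vertex, the two interpolation endpoints and
# their field values; faces_list[0] indexes into those vertices.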