diff --git a/gsplat/_torch_impl.py b/gsplat/_torch_impl.py index 8fe5951c6..760599a8e 100644 --- a/gsplat/_torch_impl.py +++ b/gsplat/_torch_impl.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from jaxtyping import Float from torch import Tensor +from typing import Tuple def compute_sh_color( @@ -116,15 +117,15 @@ def quat_to_rotmat(quat: Tensor) -> Tensor: w, x, y, z = torch.unbind(F.normalize(quat, dim=-1), dim=-1) mat = torch.stack( [ - 1 - 2 * (y ** 2 + z ** 2), + 1 - 2 * (y**2 + z**2), 2 * (x * y - w * z), 2 * (x * z + w * y), 2 * (x * y + w * z), - 1 - 2 * (x ** 2 + z ** 2), + 1 - 2 * (x**2 + z**2), 2 * (y * z - w * x), 2 * (x * z - w * y), 2 * (y * z + w * x), - 1 - 2 * (x ** 2 + y ** 2), + 1 - 2 * (x**2 + y**2), ], dim=-1, ) @@ -149,7 +150,7 @@ def project_cov3d_ewa( fy: float, tan_fovx: float, tan_fovy: float, -) -> Tensor: +) -> Tuple[Tensor, Tensor]: assert mean3d.shape[-1] == 3, mean3d.shape assert cov3d.shape[-2:] == (3, 3), cov3d.shape assert viewmat.shape[-2:] == (4, 4), viewmat.shape @@ -158,7 +159,7 @@ def project_cov3d_ewa( t = torch.einsum("...ij,...j->...i", W, mean3d) + p # (..., 3) rz = 1.0 / t[..., 2] # (...,) - rz2 = rz ** 2 # (...,) + rz2 = rz**2 # (...,) lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device) lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device) @@ -174,9 +175,12 @@ def project_cov3d_ewa( T = torch.matmul(J, W) # (..., 2, 3) cov2d = torch.einsum("...ij,...jk,...kl->...il", T, cov3d, T.transpose(-1, -2)) # add a little blur along axes and (TODO save upper triangular elements) + det_orig = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1] cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.3 cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3 - return cov2d[..., :2, :2] + det_blur = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1] + compensation = torch.sqrt(torch.clamp(det_orig / det_blur, min=0)) + return cov2d[..., :2, :2], compensation.detach() def 
compute_cov2d_bounds(cov2d_mat: Tensor): @@ -198,8 +202,8 @@ def compute_cov2d_bounds(cov2d_mat: Tensor): dim=-1, ) # (..., 3) b = (cov2d[..., 0, 0] + cov2d[..., 1, 1]) / 2 # (...,) - v1 = b + torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,) - v2 = b - torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,) + v1 = b + torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,) + v2 = b - torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,) radius = torch.ceil(3.0 * torch.sqrt(torch.max(v1, v2))) # (...,) radius_all = torch.zeros(*cov2d_mat.shape[:-2], device=cov2d_mat.device) conic_all = torch.zeros(*cov2d_mat.shape[:-2], 3, device=cov2d_mat.device) @@ -272,7 +276,9 @@ def project_gaussians_forward( tan_fovy = 0.5 * img_size[1] / fy p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh) cov3d = scale_rot_to_cov3d(scales, glob_scale, quats) - cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy) + cov2d, compensation = project_cov3d_ewa( + means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy + ) conic, radius, det_valid = compute_cov2d_bounds(cov2d) xys = project_pix(fullmat, means3d, img_size, (cx, cy)) tile_min, tile_max = get_tile_bbox(xys, radius, tile_bounds) @@ -290,6 +296,7 @@ def project_gaussians_forward( xys = torch.where(~mask[..., None], 0, xys) cov3d = torch.where(~mask[..., None, None], 0, cov3d) cov2d = torch.where(~mask[..., None, None], 0, cov2d) + compensation = torch.where(~mask, 0, compensation) num_tiles_hit = torch.where(~mask, 0, num_tiles_hit) depths = torch.where(~mask, 0, depths) @@ -297,7 +304,17 @@ def project_gaussians_forward( cov3d_triu = cov3d[..., i, j] i, j = torch.triu_indices(2, 2) cov2d_triu = cov2d[..., i, j] - return cov3d_triu, cov2d_triu, xys, depths, radii, conic, num_tiles_hit, mask + return ( + cov3d_triu, + cov2d_triu, + xys, + depths, + radii, + conic, + compensation, + num_tiles_hit, + mask, + ) def map_gaussian_to_intersects( @@ -339,7 +356,6 @@ def 
get_tile_bin_edges(num_intersects, isect_ids_sorted, tile_bounds): ) for idx in range(num_intersects): - cur_tile_idx = isect_ids_sorted[idx] >> 32 if idx == 0: diff --git a/gsplat/cuda/csrc/bindings.cu b/gsplat/cuda/csrc/bindings.cu index 6ade94cf3..9b4894209 100644 --- a/gsplat/cuda/csrc/bindings.cu +++ b/gsplat/cuda/csrc/bindings.cu @@ -122,6 +122,7 @@ std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor> project_gaussians_forward_tensor( const int num_points, @@ -162,6 +163,8 @@ project_gaussians_forward_tensor( torch::zeros({num_points}, means3d.options().dtype(torch::kInt32)); torch::Tensor conics_d = torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32)); + torch::Tensor compensation_d = + torch::zeros({num_points}, means3d.options().dtype(torch::kFloat32)); torch::Tensor num_tiles_hit_d = torch::zeros({num_points}, means3d.options().dtype(torch::kInt32)); @@ -185,11 +188,12 @@ project_gaussians_forward_tensor( depths_d.contiguous().data_ptr(), radii_d.contiguous().data_ptr(), (float3 *)conics_d.contiguous().data_ptr(), + compensation_d.contiguous().data_ptr(), num_tiles_hit_d.contiguous().data_ptr() ); return std::make_tuple( - cov3d_d, xys_d, depths_d, radii_d, conics_d, num_tiles_hit_d + cov3d_d, xys_d, depths_d, radii_d, conics_d, compensation_d, num_tiles_hit_d ); } diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index d29a0b0b1..09ea16599 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -40,6 +40,7 @@ std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor> project_gaussians_forward_tensor( const int num_points, diff --git a/gsplat/cuda/csrc/forward.cu b/gsplat/cuda/csrc/forward.cu index 1f4495b3c..522ed36d1 100644 --- a/gsplat/cuda/csrc/forward.cu +++ b/gsplat/cuda/csrc/forward.cu @@ -26,6 +26,7 @@ __global__ void project_gaussians_forward_kernel( float* __restrict__ depths, int* __restrict__ radii, float3* 
__restrict__ conics, + float* __restrict__ compensation, int32_t* __restrict__ num_tiles_hit ) { unsigned idx = cg::this_grid().thread_rank(); // idx of thread within grid @@ -61,8 +62,11 @@ __global__ void project_gaussians_forward_kernel( float cy = intrins.w; float tan_fovx = 0.5 * img_size.x / fx; float tan_fovy = 0.5 * img_size.y / fy; - float3 cov2d = project_cov3d_ewa( - p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy + float3 cov2d; + float comp; + project_cov3d_ewa( + p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy, + cov2d, comp ); // printf("cov2d %d, %.2f %.2f %.2f\n", idx, cov2d.x, cov2d.y, cov2d.z); @@ -88,6 +92,7 @@ __global__ void project_gaussians_forward_kernel( depths[idx] = p_view.z; radii[idx] = (int)radius; xys[idx] = center; + compensation[idx] = comp; // printf( // "point %d x %.2f y %.2f z %.2f, radius %d, # tiles %d, tile_min %d // %d, tile_max %d %d\n", idx, center.x, center.y, depths[idx], @@ -372,14 +377,16 @@ __global__ void rasterize_forward( } // device helper to approximate projected 2d cov from 3d mean and cov -__device__ float3 project_cov3d_ewa( +__device__ void project_cov3d_ewa( const float3& __restrict__ mean3d, const float* __restrict__ cov3d, const float* __restrict__ viewmat, const float fx, const float fy, const float tan_fovx, - const float tan_fovy + const float tan_fovy, + float3 &cov2d, + float &compensation ) { // clip the // we expect row major matrices as input, glm uses column major @@ -437,7 +444,14 @@ __device__ float3 project_cov3d_ewa( glm::mat3 cov = T * V * glm::transpose(T); // add a little blur along axes and save upper triangular elements - return make_float3(float(cov[0][0]) + 0.3f, float(cov[0][1]), float(cov[1][1]) + 0.3f); + // and compute the density compensation factor due to the blurs + float c00 = cov[0][0], c11 = cov[1][1], c01 = cov[0][1]; + float det_orig = c00 * c11 - c01 * c01; + cov2d.x = c00 + 0.3f; + cov2d.y = c01; + cov2d.z = c11 + 0.3f; + float det_blur = cov2d.x * cov2d.z 
- cov2d.y * cov2d.y; + compensation = std::sqrt(std::max(0.f, det_orig / det_blur)); } // device helper to get 3D covariance from scale and quat parameters diff --git a/gsplat/cuda/csrc/forward.cuh b/gsplat/cuda/csrc/forward.cuh index a4bd4acc6..a3acc8ca2 100644 --- a/gsplat/cuda/csrc/forward.cuh +++ b/gsplat/cuda/csrc/forward.cuh @@ -20,6 +20,7 @@ __global__ void project_gaussians_forward_kernel( float* __restrict__ depths, int* __restrict__ radii, float3* __restrict__ conics, + float* __restrict__ compensation, int32_t* __restrict__ num_tiles_hit ); @@ -57,14 +58,16 @@ __global__ void nd_rasterize_forward( ); // device helper to approximate projected 2d cov from 3d mean and cov -__device__ float3 project_cov3d_ewa( +__device__ void project_cov3d_ewa( const float3 &mean3d, const float *cov3d, const float *viewmat, const float fx, const float fy, const float tan_fovx, - const float tan_fovy + const float tan_fovy, + float3 &cov2d, + float &comp ); // device helper to get 3D covariance from scale and quat parameters diff --git a/gsplat/project_gaussians.py b/gsplat/project_gaussians.py index 4fd310fa6..91ccb529c 100644 --- a/gsplat/project_gaussians.py +++ b/gsplat/project_gaussians.py @@ -24,7 +24,7 @@ def project_gaussians( img_width: int, tile_bounds: Tuple[int, int, int], clip_thresh: float = 0.01, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, Tensor]: +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """This function projects 3D gaussians to 2D using the EWA splatting method for gaussian splatting. Note: @@ -47,12 +47,13 @@ def project_gaussians( clip_thresh (float): minimum z depth threshold. Returns: - A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}: + A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}: - **xys** (Tensor): x,y locations of 2D gaussian projections. - **depths** (Tensor): z depth of gaussians. - **radii** (Tensor): radii of 2D gaussian projections. 
- **conics** (Tensor): conic parameters for 2D gaussian. + - **compensation** (Tensor): density compensation factor accounting for the blur added to each 2D gaussian. - **num_tiles_hit** (Tensor): number of tiles hit per gaussian. - **cov3d** (Tensor): 3D covariances. """ @@ -105,6 +106,7 @@ def forward( depths, radii, conics, + compensation, num_tiles_hit, ) = _C.project_gaussians_forward( num_points, @@ -146,10 +148,19 @@ def forward( conics, ) - return (xys, depths, radii, conics, num_tiles_hit, cov3d) + return (xys, depths, radii, conics, compensation, num_tiles_hit, cov3d) @staticmethod - def backward(ctx, v_xys, v_depths, v_radii, v_conics, v_num_tiles_hit, v_cov3d): + def backward( + ctx, + v_xys, + v_depths, + v_radii, + v_conics, + v_compensation, + v_num_tiles_hit, + v_cov3d, + ): ( means3d, scales, diff --git a/tests/test_project_gaussians.py b/tests/test_project_gaussians.py index cd37677cf..159ac1c62 100644 --- a/tests/test_project_gaussians.py +++ b/tests/test_project_gaussians.py @@ -66,7 +66,15 @@ def test_project_gaussians_forward(): BLOCK_X, BLOCK_Y = 16, 16 tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1 - (cov3d, xys, depths, radii, conics, num_tiles_hit,) = _C.project_gaussians_forward( + ( + cov3d, + xys, + depths, + radii, + conics, + compensation, + num_tiles_hit, + ) = _C.project_gaussians_forward( num_points, means3d, scales, @@ -93,6 +101,7 @@ def test_project_gaussians_forward(): _depths, _radii, _conics, + _compensation, _num_tiles_hit, _masks, ) = _torch_impl.project_gaussians_forward( @@ -114,6 +123,7 @@ def test_project_gaussians_forward(): check_close(depths, _depths) check_close(radii, _radii) check_close(conics, _conics) + check_close(compensation, _compensation) check_close(num_tiles_hit, _num_tiles_hit) print("passed project_gaussians_forward test") @@ -156,6 +166,7 @@ def test_project_gaussians_backward(): radii, conics, _, + _, masks, ) = _torch_impl.project_gaussians_forward( means3d, @@ -219,7 +230,7 @@ def 
project_cov3d_ewa_partial(mean3d, cov3d): i, j = torch.triu_indices(3, 3) cov3d_mat[..., i, j] = cov3d cov3d_mat[..., [1, 2, 2], [0, 0, 1]] = cov3d[..., [1, 2, 4]] - cov2d = _torch_impl.project_cov3d_ewa( + cov2d, _ = _torch_impl.project_cov3d_ewa( mean3d, cov3d_mat, viewmat, fx, fy, tan_fovx, tan_fovy ) ii, jj = torch.triu_indices(2, 2)