From 9c06a245c1dab693b5815d79f727ddaacbbabaa5 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 15 Jul 2024 11:44:58 -0700 Subject: [PATCH] support packed and sparse --- gsplat/cuda/_wrapper.py | 14 +++++++++++++- gsplat/cuda/csrc/bindings.h | 3 ++- gsplat/cuda/csrc/fully_fused_projection_fwd.cu | 2 +- .../csrc/fully_fused_projection_packed_bwd.cu | 17 ++++++++++++++++- .../csrc/fully_fused_projection_packed_fwd.cu | 17 +++++++++++++---- gsplat/rendering.py | 1 + 6 files changed, 46 insertions(+), 8 deletions(-) diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 1dbadf9c9..b3a31f02f 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -958,6 +958,7 @@ def forward( radii, means2d, depths, + normals, conics, compensations, ) = _make_lazy_cuda_func("fully_fused_projection_packed_fwd")( @@ -994,7 +995,16 @@ def forward( ctx.eps2d = eps2d ctx.sparse_grad = sparse_grad - return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations + return ( + camera_ids, + gaussian_ids, + radii, + means2d, + depths, + normals, + conics, + compensations, + ) @staticmethod def backward( @@ -1004,6 +1014,7 @@ def backward( v_radii, v_means2d, v_depths, + v_normals, v_conics, v_compensations, ): @@ -1044,6 +1055,7 @@ def backward( compensations, v_means2d.contiguous(), v_depths.contiguous(), + v_normals.contiguous(), v_conics.contiguous(), v_compensations, ctx.needs_input_grad[4], # viewmats_requires_grad diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index bde378773..4350d9810 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -180,7 +180,7 @@ compute_sh_bwd_tensor(const uint32_t K, const uint32_t degrees_to_use, * Packed Version ****************************************************************************************/ std::tuple + torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> fully_fused_projection_packed_fwd_tensor( const torch::Tensor &means, // [N, 3] const at::optional &covars, // [N, 6] @@ -210,6 +210,7 @@ fully_fused_projection_packed_bwd_tensor( // grad outputs const torch::Tensor &v_means2d, // [nnz, 2] const torch::Tensor &v_depths, // [nnz] + const torch::Tensor &v_normals, // [nnz, 3] const torch::Tensor &v_conics, // [nnz, 3] const at::optional &v_compensations, // [nnz] optional const bool viewmats_requires_grad, const bool sparse_grad); diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index c491b1b6f..8a44269f7 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -31,7 +31,7 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, int32_t *__restrict__ radii, // [C, N] T *__restrict__ means2d, // [C, N, 2] T *__restrict__ depths, // [C, N] - T *__restrict__ normals, // [C, N, 3] + T *__restrict__ normals, // [C, N, 3] T *__restrict__ conics, // [C, N, 3] T *__restrict__ compensations // [C, N] optional ) { diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index 476390814..661228c39 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -34,6 +34,7 @@ __global__ void fully_fused_projection_packed_bwd_kernel( // grad outputs const T *__restrict__ v_means2d, // [nnz, 2] const T *__restrict__ v_depths, // [nnz] + const T *__restrict__ v_normals, // [nnz, 3] const T *__restrict__ v_conics, // [nnz, 3] const T *__restrict__ v_compensations, // [nnz] optional const bool sparse_grad, // whether the outputs are in COO format [nnz, ...] @@ -61,6 +62,7 @@ __global__ void fully_fused_projection_packed_bwd_kernel( v_means2d += idx * 2; v_depths += idx; + v_normals += idx * 3; v_conics += idx * 3; // vjp: compute the inverse of the 2d covariance @@ -154,6 +156,11 @@ __global__ void fully_fused_projection_packed_bwd_kernel( v_scales[0] = v_scale[0]; v_scales[1] = v_scale[1]; v_scales[2] = v_scale[2]; + + // add contribution from v_normals. Please check if this is correct. + mat3 v_R = quat_to_rotmat(quat); + v_R[2] += glm::make_vec3(v_normals); + quat_to_rotmat_vjp(quat, v_R, v_quat); } } else { // write out results with dense layout @@ -188,6 +195,12 @@ __global__ void fully_fused_projection_packed_bwd_kernel( vec4 v_quat(0.f); vec3 v_scale(0.f); quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + + // add contribution from v_normals. Please check if this is correct. + mat3 v_R = quat_to_rotmat(quat); + v_R[2] += glm::make_vec3(v_normals); + quat_to_rotmat_vjp(quat, v_R, v_quat); + warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); if (warp_group_g.thread_rank() == 0) { @@ -240,6 +253,7 @@ fully_fused_projection_packed_bwd_tensor( // grad outputs const torch::Tensor &v_means2d, // [nnz, 2] const torch::Tensor &v_depths, // [nnz] + const torch::Tensor &v_normals, // [nnz, 3] const torch::Tensor &v_conics, // [nnz, 3] const at::optional &v_compensations, // [nnz] optional const bool viewmats_requires_grad, const bool sparse_grad) { @@ -259,6 +273,7 @@ fully_fused_projection_packed_bwd_tensor( CHECK_INPUT(conics); CHECK_INPUT(v_means2d); CHECK_INPUT(v_depths); + CHECK_INPUT(v_normals); CHECK_INPUT(v_conics); if (compensations.has_value()) { CHECK_INPUT(compensations.value()); @@ -309,7 +324,7 @@ fully_fused_projection_packed_bwd_tensor( compensations.has_value() ? compensations.value().data_ptr() : nullptr, v_means2d.data_ptr(), v_depths.data_ptr(), - v_conics.data_ptr(), + v_normals.data_ptr(), v_conics.data_ptr(), v_compensations.has_value() ? v_compensations.value().data_ptr() : nullptr, sparse_grad, v_means.data_ptr(), diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 7f7082f6e..719895e38 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -36,6 +36,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( int32_t *__restrict__ radii, // [nnz] T *__restrict__ means2d, // [nnz, 2] T *__restrict__ depths, // [nnz] + T *__restrict__ normals, // [nnz, 3] T *__restrict__ conics, // [nnz, 3] T *__restrict__ compensations // [nnz] optional ) { @@ -75,6 +76,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( mat2 covar2d; vec2 mean2d; mat2 covar2d_inv; + vec3 normal; T compensation; T det; if (valid) { @@ -92,6 +94,9 @@ __global__ void fully_fused_projection_packed_fwd_kernel( quats += col_idx * 4; scales += col_idx * 3; quat_scale_to_covar_preci(glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr); + + glm::mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); + normal = rotmat[2]; } mat3 covar_c; covar_world_to_cam(R, covar, covar_c); @@ -163,6 +168,9 @@ __global__ void fully_fused_projection_packed_fwd_kernel( means2d[thread_data * 2] = mean2d.x; means2d[thread_data * 2 + 1] = mean2d.y; depths[thread_data] = mean_c.z; + normals[thread_data * 3] = normal.x; + normals[thread_data * 3 + 1] = normal.y; + normals[thread_data * 3 + 2] = normal.z; conics[thread_data * 3] = covar2d_inv[0][0]; conics[thread_data * 3 + 1] = covar2d_inv[0][1]; conics[thread_data * 3 + 2] = covar2d_inv[1][1]; @@ -183,7 +191,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( } std::tuple + torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> fully_fused_projection_packed_fwd_tensor( const torch::Tensor &means, // [N, 3] const at::optional &covars, // [N, 6] @@ -232,7 +240,7 @@ fully_fused_projection_packed_fwd_tensor( viewmats.data_ptr(), Ks.data_ptr(), image_width, image_height, eps2d, near_plane, far_plane, radius_clip, nullptr, block_cnts.data_ptr(), nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr); + nullptr, nullptr, nullptr, nullptr); block_accum = torch::cumsum(block_cnts, 0, torch::kInt32); nnz = block_accum[-1].item(); } else { @@ -246,6 +254,7 @@ fully_fused_projection_packed_fwd_tensor( torch::Tensor radii = torch::empty({nnz}, means.options().dtype(torch::kInt32)); torch::Tensor means2d = torch::empty({nnz, 2}, means.options()); torch::Tensor depths = torch::empty({nnz}, means.options()); + torch::Tensor normals = torch::empty({nnz, 3}, means.options()); torch::Tensor conics = torch::empty({nnz, 3}, means.options()); torch::Tensor compensations; if (calc_compensations) { @@ -264,12 +273,12 @@ fully_fused_projection_packed_fwd_tensor( nullptr, indptr.data_ptr(), camera_ids.data_ptr(), gaussian_ids.data_ptr(), radii.data_ptr(), means2d.data_ptr(), depths.data_ptr(), - conics.data_ptr(), + normals.data_ptr(), conics.data_ptr(), calc_compensations ? compensations.data_ptr() : nullptr); } else { indptr.fill_(0); } return std::make_tuple(indptr, camera_ids, gaussian_ids, radii, means2d, depths, - conics, compensations); + normals, conics, compensations); } diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 40d548dc3..3ef25bf6e 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -242,6 +242,7 @@ def rasterization( radii, means2d, depths, + normals, conics, compensations, ) = proj_results