From 7c78f4b1543e5b995c593bb000a8d621ad687ae2 Mon Sep 17 00:00:00 2001
From: Ruilong Li <397653553@qq.com>
Date: Wed, 7 Aug 2024 21:09:34 +0000
Subject: [PATCH] rasterize_to_pixels_fwd_kernel

---
 gsplat/cuda/csrc/bindings.h                 |  6 ++-
 gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu | 44 ++++++++++++++++-----
 2 files changed, 38 insertions(+), 12 deletions(-)

diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h
index baa78fbf5..1e5e5245d 100644
--- a/gsplat/cuda/csrc/bindings.h
+++ b/gsplat/cuda/csrc/bindings.h
@@ -111,7 +111,8 @@ torch::Tensor isect_offset_encode_tensor(const torch::Tensor &isect_ids, // [n_i
                                          const uint32_t C, const uint32_t tile_width,
                                          const uint32_t tile_height);
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_tensor(
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+rasterize_to_pixels_fwd_tensor(
     // Gaussian parameters
     const torch::Tensor &means2d, // [C, N, 2]
     const torch::Tensor &conics,  // [C, N, 3]
@@ -122,7 +123,8 @@
     const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
     // intersections
     const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
-    const torch::Tensor &flatten_ids // [n_isects]
+    const torch::Tensor &flatten_ids, // [n_isects]
+    const bool calc_depth // whether to calculate depth
 );
 
 std::tuple
diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu
index 8ecd569f2..c1dcf9aa1 100644
--- a/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu
+++ b/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu
@@ -25,6 +25,7 @@ __global__ void rasterize_to_pixels_fwd_kernel(
     const int32_t *__restrict__ flatten_ids, // [n_isects]
     S *__restrict__ render_colors, // [C, image_height, image_width, COLOR_DIM]
     S *__restrict__ render_alphas, // [C, image_height, image_width, 1]
+    S *__restrict__ render_depths, // [C, image_height, image_width, 1] optional
    int32_t *__restrict__ last_ids // [C, image_height, image_width]
 ) {
     // each thread draws one pixel, but also timeshares caching gaussians in a
@@ -39,6 +40,9 @@
     tile_offsets += camera_id * tile_height * tile_width;
     render_colors += camera_id * image_height * image_width * COLOR_DIM;
     render_alphas += camera_id * image_height * image_width;
+    if (render_depths != nullptr) {
+        render_depths += camera_id * image_height * image_width;
+    }
     last_ids += camera_id * image_height * image_width;
     if (backgrounds != nullptr) {
         backgrounds += camera_id * COLOR_DIM;
@@ -84,6 +88,7 @@
     uint32_t tr = block.thread_rank();
 
     S pix_out[COLOR_DIM] = {0.f};
+    S depth_out = 0.f;
     for (uint32_t b = 0; b < num_batches; ++b) {
         // resync all threads before beginning next batch
         // end early if entire tile is done
@@ -140,10 +145,18 @@
             int32_t g = id_batch[t];
             const S vis = alpha * T;
             const S *c_ptr = colors + g * COLOR_DIM;
+            // accumulate color
             PRAGMA_UNROLL
             for (uint32_t k = 0; k < COLOR_DIM; ++k) {
                 pix_out[k] += c_ptr[k] * vis;
             }
+            // accumulate depth
+            if (render_depths != nullptr) {
+                S depth =
+                    mean2d.z +
+                    (conic02 * (mean2d.x - px) + conic12 * (mean2d.y - py)) / conic22;
+                depth_out += depth * vis;
+            }
             cur_idx = batch_start + t;
 
             T = next_T;
@@ -162,13 +175,17 @@
             render_colors[pix_id * COLOR_DIM + k] = backgrounds == nullptr ?
                 pix_out[k] : (pix_out[k] + T * backgrounds[k]);
         }
+        if (render_depths != nullptr) {
+            render_depths[pix_id] = depth_out;
+        }
         // index in bin of last gaussian in this pixel
         last_ids[pix_id] = static_cast<int32_t>(cur_idx);
     }
 }
 
 template <uint32_t CDIM>
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim(
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+call_kernel_with_dim(
     // Gaussian parameters
     const torch::Tensor &means2d, // [C, N, 3] or [nnz, 3]
     const torch::Tensor &conics,  // [C, N, 6] or [nnz, 6]
@@ -179,8 +196,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim(
     const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
     // intersections
     const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
-    const torch::Tensor &flatten_ids // [n_isects]
-) {
+    const torch::Tensor &flatten_ids, // [n_isects]
+    bool calc_depth) {
     DEVICE_GUARD(means2d);
     CHECK_INPUT(means2d);
     CHECK_INPUT(conics);
@@ -209,6 +226,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim(
                                   means2d.options().dtype(torch::kFloat32));
     torch::Tensor alphas = torch::empty({C, image_height, image_width, 1},
                                         means2d.options().dtype(torch::kFloat32));
+    torch::Tensor depths;
+    if (calc_depth) {
+        depths = torch::empty({C, image_height, image_width, 1},
+                              means2d.options().dtype(torch::kFloat32));
+    }
     torch::Tensor last_ids = torch::empty({C, image_height, image_width},
                                           means2d.options().dtype(torch::kInt32));
 
@@ -236,12 +258,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim(
         image_width, image_height, tile_size, tile_width, tile_height,
         tile_offsets.data_ptr<int32_t>(), flatten_ids.data_ptr<int32_t>(),
         renders.data_ptr<float>(), alphas.data_ptr<float>(),
+        calc_depth ? depths.data_ptr<float>() : nullptr,
         last_ids.data_ptr<int32_t>());
 
-    return std::make_tuple(renders, alphas, last_ids);
+    return std::make_tuple(renders, alphas, depths, last_ids);
 }
 
-std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_tensor(
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+rasterize_to_pixels_fwd_tensor(
     // Gaussian parameters
     const torch::Tensor &means2d, // [C, N, 3] or [nnz, 3]
     const torch::Tensor &conics,  // [C, N, 6] or [nnz, 6]
@@ -252,16 +276,16 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_
     const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
     // intersections
     const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
-    const torch::Tensor &flatten_ids // [n_isects]
-) {
+    const torch::Tensor &flatten_ids, // [n_isects]
+    bool calc_depth) {
     CHECK_INPUT(colors);
     uint32_t channels = colors.size(-1);
 
 #define __GS__CALL_(N)                                                         \
     case N:                                                                    \
-        return call_kernel_with_dim<N>(means2d, conics, colors, opacities,     \
-                                       backgrounds, image_width, image_height, \
-                                       tile_size, tile_offsets, flatten_ids);
+        return call_kernel_with_dim<N>(                                        \
+            means2d, conics, colors, opacities, backgrounds, image_width,      \
+            image_height, tile_size, tile_offsets, flatten_ids, calc_depth);
 
 // TODO: an optimization can be done by passing the actual number of channels into
 // the kernel functions and avoid necessary global memory writes. This requires
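
Note for readers of this patch: the new `render_depths` buffer is composited with the
same front-to-back alpha-blending rule as color. Each Gaussian contributes its
per-pixel depth (here the plane depth `mean2d.z + (conic02*(mean2d.x - px) +
conic12*(mean2d.y - py)) / conic22`) weighted by `vis = alpha * T`. Below is a
minimal standalone C++ sketch of just that accumulation rule; the `Splat` struct and
`composite_depth` function are hypothetical names for illustration, not part of
gsplat or this patch.

    // Standalone sketch of the front-to-back depth compositing added by this
    // patch to rasterize_to_pixels_fwd_kernel. Compile with: g++ -std=c++11
    #include <cstdio>
    #include <vector>

    struct Splat {
        float alpha; // opacity after evaluating the 2D Gaussian at this pixel
        float depth; // per-pixel depth of this splat along the camera ray
    };

    // Composite depth exactly like color: each splat contributes
    // depth * alpha * T, where T is the transmittance accumulated so far.
    float composite_depth(const std::vector<Splat> &sorted_front_to_back) {
        float T = 1.0f;         // transmittance, matches `T` in the kernel
        float depth_out = 0.0f; // matches `depth_out` in the kernel
        for (const Splat &s : sorted_front_to_back) {
            const float vis = s.alpha * T; // matches `vis = alpha * T`
            depth_out += s.depth * vis;    // matches `depth_out += depth * vis`
            T *= 1.0f - s.alpha;           // matches `T = next_T`
        }
        return depth_out; // unnormalized; divide by (1 - T) for expected depth
    }

    int main() {
        std::vector<Splat> splats = {{0.5f, 1.0f}, {0.5f, 3.0f}};
        // 0.5*1.0 + 0.25*3.0 = 1.25
        printf("accumulated depth: %f\n", composite_depth(splats));
        return 0;
    }

As the sketch's return comment notes, the kernel writes the unnormalized
accumulation; a caller wanting the expected depth would divide by the rendered
alpha (1 - T).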