diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 9043181ce..2fa0c4937 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -379,6 +379,7 @@ def rasterize_to_pixels( isect_offsets: Tensor, # [C, tile_height, tile_width] flatten_ids: Tensor, # [n_isects] backgrounds: Optional[Tensor] = None, # [C, channels] + masks: Optional[Tensor] = None, # [C, tile_height, tile_width] packed: bool = False, absgrad: bool = False, ) -> Tuple[Tensor, Tensor]: @@ -395,6 +396,7 @@ def rasterize_to_pixels( isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width] flatten_ids: The global flatten indices in [C * N] or [nnz] from `isect_tiles()`. [n_isects] backgrounds: Background colors. [C, channels]. Default: None. + masks: Optional boolean tile mask. Tiles whose mask value is False are skipped during rasterization and filled with the background color (zeros if `backgrounds` is None). [C, tile_height, tile_width]. Default: None. packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False. absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False. 
@@ -422,6 +424,9 @@ def rasterize_to_pixels( if backgrounds is not None: assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape backgrounds = backgrounds.contiguous() + if masks is not None: + assert masks.shape == isect_offsets.shape, masks.shape + masks = masks.contiguous() # Pad the channels to the nearest supported number if necessary channels = colors.shape[-1] @@ -484,6 +489,7 @@ def rasterize_to_pixels( colors.contiguous(), opacities.contiguous(), backgrounds, + masks, image_width, image_height, tile_size, @@ -814,6 +820,7 @@ def forward( colors: Tensor, # [C, N, D] opacities: Tensor, # [C, N] backgrounds: Tensor, # [C, D], Optional + masks: Tensor, # [C, tile_height, tile_width], Optional width: int, height: int, tile_size: int, @@ -829,6 +836,7 @@ def forward( colors, opacities, backgrounds, + masks, width, height, tile_size, @@ -842,6 +850,7 @@ def forward( colors, opacities, backgrounds, + masks, isect_offsets, flatten_ids, render_alphas, @@ -868,6 +877,7 @@ def backward( colors, opacities, backgrounds, + masks, isect_offsets, flatten_ids, render_alphas, @@ -890,6 +900,7 @@ def backward( colors, opacities, backgrounds, + masks, width, height, tile_size, @@ -924,6 +935,7 @@ def backward( None, None, None, + None, ) diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index c983f461e..faf28bad0 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -119,6 +119,7 @@ std::tuple rasterize_to_pixels_fwd_ const torch::Tensor &colors, // [C, N, D] const torch::Tensor &opacities, // [N] const at::optional &backgrounds, // [C, D] + const at::optional &mask, // [C, tile_height, tile_width] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, // intersections @@ -134,6 +135,7 @@ rasterize_to_pixels_bwd_tensor( const torch::Tensor &colors, // [C, N, 3] const torch::Tensor &opacities, // [N] const at::optional &backgrounds, // [C, 3] + const at::optional &mask, // [C, 
tile_height, tile_width] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, // intersections diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_bwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_bwd.cu index f17b5aedd..0ec2870e8 100644 --- a/gsplat/cuda/csrc/rasterize_to_pixels_bwd.cu +++ b/gsplat/cuda/csrc/rasterize_to_pixels_bwd.cu @@ -20,6 +20,7 @@ __global__ void rasterize_to_pixels_bwd_kernel( const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] const S *__restrict__ opacities, // [C, N] or [nnz] const S *__restrict__ backgrounds, // [C, COLOR_DIM] or [nnz, COLOR_DIM] + const bool *__restrict__ masks, // [C, tile_height, tile_width] const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, const uint32_t tile_width, const uint32_t tile_height, const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] @@ -52,6 +53,15 @@ __global__ void rasterize_to_pixels_bwd_kernel( if (backgrounds != nullptr) { backgrounds += camera_id * COLOR_DIM; } + if (masks != nullptr) { + masks += camera_id * tile_height * tile_width; + } + + // when the mask is provided, do nothing and return if + // this tile is labeled as False + if (masks != nullptr && !masks[tile_id]) { + return; + } const S px = (S)j + 0.5f; const S py = (S)i + 0.5f; @@ -257,6 +267,7 @@ call_kernel_with_dim( const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] const at::optional &backgrounds, // [C, 3] + const at::optional &masks, // [C, tile_height, tile_width] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, // intersections @@ -285,6 +296,9 @@ call_kernel_with_dim( if (backgrounds.has_value()) { CHECK_INPUT(backgrounds.value()); } + if (masks.has_value()) { + CHECK_INPUT(masks.value()); + } bool packed = means2d.dim() == 2; @@ -329,6 +343,7 @@ call_kernel_with_dim( colors.data_ptr(), opacities.data_ptr(), 
backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, + masks.has_value() ? masks.value().data_ptr(): nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), render_alphas.data_ptr(), last_ids.data_ptr(), @@ -352,6 +367,7 @@ rasterize_to_pixels_bwd_tensor( const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] const torch::Tensor &opacities, // [C, N] or [nnz] const at::optional &backgrounds, // [C, 3] + const at::optional &masks, // [C, tile_height, tile_width] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, // intersections @@ -372,7 +388,7 @@ rasterize_to_pixels_bwd_tensor( #define __GS__CALL_(N) \ case N: \ return call_kernel_with_dim( \ - means2d, conics, colors, opacities, backgrounds, image_width, \ + means2d, conics, colors, opacities, backgrounds, masks, image_width, \ image_height, tile_size, tile_offsets, flatten_ids, render_alphas, \ last_ids, v_render_colors, v_render_alphas, absgrad); diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu index 3c6692e53..6a89b7148 100644 --- a/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu +++ b/gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu @@ -19,6 +19,7 @@ __global__ void rasterize_to_pixels_fwd_kernel( const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] const S *__restrict__ opacities, // [C, N] or [nnz] const S *__restrict__ backgrounds, // [C, COLOR_DIM] + const bool *__restrict__ masks, // [C, tile_height, tile_width] const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, const uint32_t tile_width, const uint32_t tile_height, const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] @@ -43,6 +44,9 @@ __global__ void rasterize_to_pixels_fwd_kernel( if (backgrounds != nullptr) { backgrounds += camera_id * COLOR_DIM; } + if (masks != nullptr) { + masks += camera_id * tile_height * 
tile_width; + } S px = (S)j + 0.5f; S py = (S)i + 0.5f; @@ -53,6 +57,15 @@ __global__ void rasterize_to_pixels_fwd_kernel( bool inside = (i < image_height && j < image_width); bool done = !inside; + // when the mask is provided, render the background color and return + // if this tile is labeled as False + if (masks != nullptr && inside && !masks[tile_id]) { + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + render_colors[pix_id * COLOR_DIM + k] = backgrounds == nullptr ? 0.0f : backgrounds[k]; + } + return; + } + // have all threads in tile process the same gaussians in batches // first collect gaussians between range.x and range.y in batches // which gaussians to look through in this tile @@ -167,6 +180,7 @@ std::tuple call_kernel_with_dim( const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] const torch::Tensor &opacities, // [C, N] or [nnz] const at::optional &backgrounds, // [C, channels] + const at::optional &masks, // [C, tile_height, tile_width] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, // intersections @@ -183,6 +197,9 @@ std::tuple call_kernel_with_dim( if (backgrounds.has_value()) { CHECK_INPUT(backgrounds.value()); } + if (masks.has_value()) { + CHECK_INPUT(masks.value()); + } bool packed = means2d.dim() == 2; uint32_t C = tile_offsets.size(0); // number of cameras @@ -225,6 +242,7 @@ std::tuple call_kernel_with_dim( reinterpret_cast *>(conics.data_ptr()), colors.data_ptr(), opacities.data_ptr(), backgrounds.has_value() ? backgrounds.value().data_ptr() : nullptr, + masks.has_value() ? 
masks.value().data_ptr() : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), @@ -240,6 +258,7 @@ std::tuple rasterize_to_pixels_fwd_ const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] const torch::Tensor &opacities, // [C, N] or [nnz] const at::optional &backgrounds, // [C, channels] + const at::optional &masks, // [C, tile_height, tile_width] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, // intersections @@ -252,7 +271,7 @@ std::tuple rasterize_to_pixels_fwd_ #define __GS__CALL_(N) \ case N: \ return call_kernel_with_dim(means2d, conics, colors, opacities, \ - backgrounds, image_width, image_height, \ + backgrounds, masks, image_width, image_height, \ tile_size, tile_offsets, flatten_ids); // TODO: an optimization can be done by passing the actual number of channels into