Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional masks in rasterize_to_pixels() to support Grendel #284

Merged
merged 3 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def rasterize_to_pixels(
isect_offsets: Tensor, # [C, tile_height, tile_width]
flatten_ids: Tensor, # [n_isects]
backgrounds: Optional[Tensor] = None, # [C, channels]
masks: Optional[Tensor] = None, # [C, tile_height, tile_width]
packed: bool = False,
absgrad: bool = False,
) -> Tuple[Tensor, Tensor]:
Expand All @@ -395,6 +396,7 @@ def rasterize_to_pixels(
isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width]
flatten_ids: The global flatten indices in [C * N] or [nnz] from `isect_tiles()`. [n_isects]
backgrounds: Background colors. [C, channels]. Default: None.
masks: Optional masks to support Grendel's local workload strategy. [C, tile_height, tile_width]. Default: None.
liruilong940607 marked this conversation as resolved.
Show resolved Hide resolved
packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False.
absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False.

Expand Down Expand Up @@ -422,6 +424,9 @@ def rasterize_to_pixels(
if backgrounds is not None:
assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape
backgrounds = backgrounds.contiguous()
if masks is not None:
assert masks.shape[0] == C, masks.shape
liruilong940607 marked this conversation as resolved.
Show resolved Hide resolved
masks = masks.contiguous()

# Pad the channels to the nearest supported number if necessary
channels = colors.shape[-1]
Expand Down Expand Up @@ -484,6 +489,7 @@ def rasterize_to_pixels(
colors.contiguous(),
opacities.contiguous(),
backgrounds,
masks,
image_width,
image_height,
tile_size,
Expand Down Expand Up @@ -814,6 +820,7 @@ def forward(
colors: Tensor, # [C, N, D]
opacities: Tensor, # [C, N]
backgrounds: Tensor, # [C, D], Optional
masks: Tensor, # [C, tile_height, tile_width], Optional
width: int,
height: int,
tile_size: int,
Expand All @@ -829,6 +836,7 @@ def forward(
colors,
opacities,
backgrounds,
masks,
width,
height,
tile_size,
Expand All @@ -842,6 +850,7 @@ def forward(
colors,
opacities,
backgrounds,
masks,
isect_offsets,
flatten_ids,
render_alphas,
Expand All @@ -868,6 +877,7 @@ def backward(
colors,
opacities,
backgrounds,
masks,
isect_offsets,
flatten_ids,
render_alphas,
Expand All @@ -890,6 +900,7 @@ def backward(
colors,
opacities,
backgrounds,
masks,
width,
height,
tile_size,
Expand Down Expand Up @@ -924,6 +935,7 @@ def backward(
None,
None,
None,
None,
)


Expand Down
2 changes: 2 additions & 0 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_
const torch::Tensor &colors, // [C, N, D]
const torch::Tensor &opacities, // [N]
const at::optional<torch::Tensor> &backgrounds, // [C, D]
const at::optional<torch::Tensor> &mask, // [C, tile_height, tile_width]
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
Expand All @@ -134,6 +135,7 @@ rasterize_to_pixels_bwd_tensor(
const torch::Tensor &colors, // [C, N, 3]
const torch::Tensor &opacities, // [N]
const at::optional<torch::Tensor> &backgrounds, // [C, 3]
const at::optional<torch::Tensor> &mask, // [C, tile_height, tile_width]
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
Expand Down
18 changes: 17 additions & 1 deletion gsplat/cuda/csrc/rasterize_to_pixels_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ __global__ void rasterize_to_pixels_bwd_kernel(
const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM]
const S *__restrict__ opacities, // [C, N] or [nnz]
const S *__restrict__ backgrounds, // [C, COLOR_DIM] or [nnz, COLOR_DIM]
const bool *__restrict__ masks, // [C, tile_height, tile_width]
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
const uint32_t tile_width, const uint32_t tile_height,
const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width]
Expand Down Expand Up @@ -52,6 +53,15 @@ __global__ void rasterize_to_pixels_bwd_kernel(
if (backgrounds != nullptr) {
backgrounds += camera_id * COLOR_DIM;
}
if (masks != nullptr) {
masks += camera_id * tile_height * tile_width;
}

// when the mask is provided, do nothing and return if
// this tile is labeled as not local
liruilong940607 marked this conversation as resolved.
Show resolved Hide resolved
if (masks != nullptr && !masks[tile_id]) {
return;
}

const S px = (S)j + 0.5f;
const S py = (S)i + 0.5f;
Expand Down Expand Up @@ -257,6 +267,7 @@ call_kernel_with_dim(
const torch::Tensor &colors, // [C, N, 3] or [nnz, 3]
const torch::Tensor &opacities, // [C, N] or [nnz]
const at::optional<torch::Tensor> &backgrounds, // [C, 3]
const at::optional<torch::Tensor> &masks, // [C, tile_height, tile_width]
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
Expand Down Expand Up @@ -285,6 +296,9 @@ call_kernel_with_dim(
if (backgrounds.has_value()) {
CHECK_INPUT(backgrounds.value());
}
if (masks.has_value()) {
CHECK_INPUT(masks.value());
}

bool packed = means2d.dim() == 2;

Expand Down Expand Up @@ -329,6 +343,7 @@ call_kernel_with_dim(
colors.data_ptr<float>(), opacities.data_ptr<float>(),
backgrounds.has_value() ? backgrounds.value().data_ptr<float>()
: nullptr,
masks.has_value() ? masks.value().data_ptr<bool>(): nullptr,
image_width, image_height, tile_size, tile_width, tile_height,
tile_offsets.data_ptr<int32_t>(), flatten_ids.data_ptr<int32_t>(),
render_alphas.data_ptr<float>(), last_ids.data_ptr<int32_t>(),
Expand All @@ -352,6 +367,7 @@ rasterize_to_pixels_bwd_tensor(
const torch::Tensor &colors, // [C, N, 3] or [nnz, 3]
const torch::Tensor &opacities, // [C, N] or [nnz]
const at::optional<torch::Tensor> &backgrounds, // [C, 3]
const at::optional<torch::Tensor> &masks, // [C, tile_height, tile_width]
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
Expand All @@ -372,7 +388,7 @@ rasterize_to_pixels_bwd_tensor(
#define __GS__CALL_(N) \
case N: \
return call_kernel_with_dim<N>( \
means2d, conics, colors, opacities, backgrounds, image_width, \
means2d, conics, colors, opacities, backgrounds, masks, image_width, \
image_height, tile_size, tile_offsets, flatten_ids, render_alphas, \
last_ids, v_render_colors, v_render_alphas, absgrad);

Expand Down
21 changes: 20 additions & 1 deletion gsplat/cuda/csrc/rasterize_to_pixels_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ __global__ void rasterize_to_pixels_fwd_kernel(
const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM]
const S *__restrict__ opacities, // [C, N] or [nnz]
const S *__restrict__ backgrounds, // [C, COLOR_DIM]
const bool *__restrict__ masks, // [C, tile_height, tile_width]
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
const uint32_t tile_width, const uint32_t tile_height,
const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width]
Expand All @@ -43,6 +44,9 @@ __global__ void rasterize_to_pixels_fwd_kernel(
if (backgrounds != nullptr) {
backgrounds += camera_id * COLOR_DIM;
}
if (masks != nullptr) {
masks += camera_id * tile_height * tile_width;
}

S px = (S)j + 0.5f;
S py = (S)i + 0.5f;
Expand All @@ -53,6 +57,15 @@ __global__ void rasterize_to_pixels_fwd_kernel(
bool inside = (i < image_height && j < image_width);
bool done = !inside;

// when the mask is provided, render 0.0 and return if this tile
liruilong940607 marked this conversation as resolved.
Show resolved Hide resolved
// is labeled as not local
if (masks != nullptr && inside && !masks[tile_id]) {
for (uint32_t k = 0; k < COLOR_DIM; ++k) {
render_colors[pix_id * COLOR_DIM + k] = 0.0f;
liruilong940607 marked this conversation as resolved.
Show resolved Hide resolved
}
return;
}

// have all threads in tile process the same gaussians in batches
// first collect gaussians between range.x and range.y in batches
// which gaussians to look through in this tile
Expand Down Expand Up @@ -167,6 +180,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim(
const torch::Tensor &colors, // [C, N, channels] or [nnz, channels]
const torch::Tensor &opacities, // [C, N] or [nnz]
const at::optional<torch::Tensor> &backgrounds, // [C, channels]
const at::optional<torch::Tensor> &masks, // [C, tile_height, tile_width]
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
Expand All @@ -183,6 +197,9 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim(
if (backgrounds.has_value()) {
CHECK_INPUT(backgrounds.value());
}
if (masks.has_value()) {
CHECK_INPUT(masks.value());
}
bool packed = means2d.dim() == 2;

uint32_t C = tile_offsets.size(0); // number of cameras
Expand Down Expand Up @@ -225,6 +242,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim(
reinterpret_cast<vec3<float> *>(conics.data_ptr<float>()),
colors.data_ptr<float>(), opacities.data_ptr<float>(),
backgrounds.has_value() ? backgrounds.value().data_ptr<float>() : nullptr,
masks.has_value() ? masks.value().data_ptr<bool>() : nullptr,
image_width, image_height, tile_size, tile_width, tile_height,
tile_offsets.data_ptr<int32_t>(), flatten_ids.data_ptr<int32_t>(),
renders.data_ptr<float>(), alphas.data_ptr<float>(),
Expand All @@ -240,6 +258,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_
const torch::Tensor &colors, // [C, N, channels] or [nnz, channels]
const torch::Tensor &opacities, // [C, N] or [nnz]
const at::optional<torch::Tensor> &backgrounds, // [C, channels]
const at::optional<torch::Tensor> &masks, // [C, tile_height, tile_width]
// image size
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size,
// intersections
Expand All @@ -252,7 +271,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_
#define __GS__CALL_(N) \
case N: \
return call_kernel_with_dim<N>(means2d, conics, colors, opacities, \
backgrounds, image_width, image_height, \
backgrounds, masks, image_width, image_height, \
tile_size, tile_offsets, flatten_ids);

// TODO: an optimization can be done by passing the actual number of channels into
Expand Down
Loading