-
Notifications
You must be signed in to change notification settings - Fork 312
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Ruilong Li
committed
Aug 7, 2024
1 parent
76ca887
commit 7c78f4b
Showing
2 changed files
with
38 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,7 @@ __global__ void rasterize_to_pixels_fwd_kernel( | |
const int32_t *__restrict__ flatten_ids, // [n_isects] | ||
S *__restrict__ render_colors, // [C, image_height, image_width, COLOR_DIM] | ||
S *__restrict__ render_alphas, // [C, image_height, image_width, 1] | ||
S *__restrict__ render_depths, // [C, image_height, image_width, 1] optional | ||
int32_t *__restrict__ last_ids // [C, image_height, image_width] | ||
) { | ||
// each thread draws one pixel, but also timeshares caching gaussians in a | ||
|
@@ -39,6 +40,9 @@ __global__ void rasterize_to_pixels_fwd_kernel( | |
tile_offsets += camera_id * tile_height * tile_width; | ||
render_colors += camera_id * image_height * image_width * COLOR_DIM; | ||
render_alphas += camera_id * image_height * image_width; | ||
if (render_depths != nullptr) { | ||
render_depths += camera_id * image_height * image_width; | ||
} | ||
last_ids += camera_id * image_height * image_width; | ||
if (backgrounds != nullptr) { | ||
backgrounds += camera_id * COLOR_DIM; | ||
|
@@ -84,6 +88,7 @@ __global__ void rasterize_to_pixels_fwd_kernel( | |
uint32_t tr = block.thread_rank(); | ||
|
||
S pix_out[COLOR_DIM] = {0.f}; | ||
S depth_out = 0.f; | ||
for (uint32_t b = 0; b < num_batches; ++b) { | ||
// resync all threads before beginning next batch | ||
// end early if entire tile is done | ||
|
@@ -140,10 +145,18 @@ __global__ void rasterize_to_pixels_fwd_kernel( | |
int32_t g = id_batch[t]; | ||
const S vis = alpha * T; | ||
const S *c_ptr = colors + g * COLOR_DIM; | ||
// accumulate color | ||
PRAGMA_UNROLL | ||
for (uint32_t k = 0; k < COLOR_DIM; ++k) { | ||
pix_out[k] += c_ptr[k] * vis; | ||
} | ||
// accumulate depth | ||
if (render_depths != nullptr) { | ||
S depth = | ||
mean2d.z + | ||
(conic02 * (mean2d.x - px) + conic12 * (mean2d.y - py)) / conic22; | ||
depth_out += depth * vis; | ||
} | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
liruilong940607
Collaborator
|
||
cur_idx = batch_start + t; | ||
|
||
T = next_T; | ||
|
@@ -162,13 +175,17 @@ __global__ void rasterize_to_pixels_fwd_kernel( | |
render_colors[pix_id * COLOR_DIM + k] = | ||
backgrounds == nullptr ? pix_out[k] : (pix_out[k] + T * backgrounds[k]); | ||
} | ||
if (render_depths != nullptr) { | ||
render_depths[pix_id] = depth_out; | ||
} | ||
// index in bin of last gaussian in this pixel | ||
last_ids[pix_id] = static_cast<int32_t>(cur_idx); | ||
} | ||
} | ||
|
||
template <uint32_t CDIM> | ||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim( | ||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> | ||
call_kernel_with_dim( | ||
// Gaussian parameters | ||
const torch::Tensor &means2d, // [C, N, 3] or [nnz, 3] | ||
const torch::Tensor &conics, // [C, N, 6] or [nnz, 6] | ||
|
@@ -179,8 +196,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim( | |
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, | ||
// intersections | ||
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] | ||
const torch::Tensor &flatten_ids // [n_isects] | ||
) { | ||
const torch::Tensor &flatten_ids, // [n_isects] | ||
bool calc_depth) { | ||
DEVICE_GUARD(means2d); | ||
CHECK_INPUT(means2d); | ||
CHECK_INPUT(conics); | ||
|
@@ -209,6 +226,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim( | |
means2d.options().dtype(torch::kFloat32)); | ||
torch::Tensor alphas = torch::empty({C, image_height, image_width, 1}, | ||
means2d.options().dtype(torch::kFloat32)); | ||
torch::Tensor depths; | ||
if (calc_depth) { | ||
depths = torch::empty({C, image_height, image_width, 1}, | ||
means2d.options().dtype(torch::kFloat32)); | ||
} | ||
torch::Tensor last_ids = torch::empty({C, image_height, image_width}, | ||
means2d.options().dtype(torch::kInt32)); | ||
|
||
|
@@ -236,12 +258,14 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> call_kernel_with_dim( | |
image_width, image_height, tile_size, tile_width, tile_height, | ||
tile_offsets.data_ptr<int32_t>(), flatten_ids.data_ptr<int32_t>(), | ||
renders.data_ptr<float>(), alphas.data_ptr<float>(), | ||
calc_depth ? depths.data_ptr<float>() : nullptr, | ||
last_ids.data_ptr<int32_t>()); | ||
|
||
return std::make_tuple(renders, alphas, last_ids); | ||
return std::make_tuple(renders, alphas, depths, last_ids); | ||
} | ||
|
||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_tensor( | ||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> | ||
rasterize_to_pixels_fwd_tensor( | ||
// Gaussian parameters | ||
const torch::Tensor &means2d, // [C, N, 3] or [nnz, 3] | ||
const torch::Tensor &conics, // [C, N, 6] or [nnz, 6] | ||
|
@@ -252,16 +276,16 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> rasterize_to_pixels_fwd_ | |
const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, | ||
// intersections | ||
const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] | ||
const torch::Tensor &flatten_ids // [n_isects] | ||
) { | ||
const torch::Tensor &flatten_ids, // [n_isects] | ||
bool calc_depth) { | ||
CHECK_INPUT(colors); | ||
uint32_t channels = colors.size(-1); | ||
|
||
#define __GS__CALL_(N) \ | ||
case N: \ | ||
return call_kernel_with_dim<N>(means2d, conics, colors, opacities, \ | ||
backgrounds, image_width, image_height, \ | ||
tile_size, tile_offsets, flatten_ids); | ||
return call_kernel_with_dim<N>( \ | ||
means2d, conics, colors, opacities, backgrounds, image_width, \ | ||
image_height, tile_size, tile_offsets, flatten_ids, calc_depth); | ||
|
||
// TODO: an optimization can be done by passing the actual number of channels into | ||
// the kernel functions and avoid necessary global memory writes. This requires | ||
|
Does this depth formulation from a gaussian have any reference in the literature?