diff --git a/diff_rast/cuda/csrc/bindings.cu b/diff_rast/cuda/csrc/bindings.cu index 541ebf2cc..429f74e25 100644 --- a/diff_rast/cuda/csrc/bindings.cu +++ b/diff_rast/cuda/csrc/bindings.cu @@ -19,10 +19,7 @@ namespace cg = cooperative_groups; __global__ void compute_cov2d_bounds_forward_kernel( - const unsigned num_pts, - const float *covs2d, - float *conics, - float *radii + const unsigned num_pts, const float *covs2d, float *conics, float *radii ) { unsigned row = cg::this_grid().thread_rank(); if (row >= num_pts) { @@ -31,7 +28,9 @@ __global__ void compute_cov2d_bounds_forward_kernel( int index = row * 3; float3 conic; float radius; - float3 cov2d{(float)covs2d[index], (float)covs2d[index + 1], (float)covs2d[index + 2]}; + float3 cov2d{ + (float)covs2d[index], (float)covs2d[index + 1], (float)covs2d[index + 2] + }; compute_cov2d_bounds(cov2d, conic, radius); conics[index] = conic.x; conics[index + 1] = conic.y; @@ -42,10 +41,11 @@ __global__ void compute_cov2d_bounds_forward_kernel( std::tuple< torch::Tensor, // output conics torch::Tensor> // output radii -compute_cov2d_bounds_forward_tensor(const int num_pts, torch::Tensor covs2d) { +compute_cov2d_bounds_forward_tensor(const int num_pts, torch::Tensor &covs2d) { CHECK_INPUT(covs2d); - torch::Tensor conics = - torch::zeros({num_pts, covs2d.size(1)}, covs2d.options().dtype(torch::kFloat32)); + torch::Tensor conics = torch::zeros( + {num_pts, covs2d.size(1)}, covs2d.options().dtype(torch::kFloat32) + ); torch::Tensor radii = torch::zeros({num_pts, 1}, covs2d.options().dtype(torch::kFloat32)); @@ -63,8 +63,8 @@ compute_cov2d_bounds_forward_tensor(const int num_pts, torch::Tensor covs2d) { torch::Tensor compute_sh_forward_tensor( const unsigned num_points, const unsigned degree, - torch::Tensor viewdirs, - torch::Tensor coeffs + torch::Tensor &viewdirs, + torch::Tensor &coeffs ) { unsigned num_bases = num_sh_bases(degree); if (coeffs.ndimension() != 3 || coeffs.size(0) != num_points || @@ -87,8 +87,8 @@ torch::Tensor compute_sh_forward_tensor( torch::Tensor compute_sh_backward_tensor( const unsigned num_points, const unsigned degree, - torch::Tensor viewdirs, - torch::Tensor v_colors + torch::Tensor &viewdirs, + torch::Tensor &v_colors ) { if (viewdirs.ndimension() != 2 || viewdirs.size(0) != num_points || viewdirs.size(1) != 3) { @@ -122,17 +122,18 @@ std::tuple< torch::Tensor> project_gaussians_forward_tensor( const int num_points, - torch::Tensor means3d, - torch::Tensor scales, + torch::Tensor &means3d, + torch::Tensor &scales, const float glob_scale, - torch::Tensor quats, - torch::Tensor viewmat, - torch::Tensor projmat, + torch::Tensor &quats, + torch::Tensor &viewmat, + torch::Tensor &projmat, const float fx, const float fy, const unsigned img_height, const unsigned img_width, - const std::tuple tile_bounds + const std::tuple tile_bounds, + const float clip_thresh ) { dim3 img_size_dim3; img_size_dim3.x = img_width; @@ -169,6 +170,7 @@ project_gaussians_forward_tensor( fy, img_size_dim3, tile_bounds_dim3, + clip_thresh, // Outputs. cov3d_d.contiguous().data_ptr(), (float2 *)xys_d.contiguous().data_ptr(), @@ -191,21 +193,21 @@ std::tuple< torch::Tensor> project_gaussians_backward_tensor( const int num_points, - torch::Tensor means3d, - torch::Tensor scales, + torch::Tensor &means3d, + torch::Tensor &scales, const float glob_scale, - torch::Tensor quats, - torch::Tensor viewmat, - torch::Tensor projmat, + torch::Tensor &quats, + torch::Tensor &viewmat, + torch::Tensor &projmat, const float fx, const float fy, const unsigned img_height, const unsigned img_width, - torch::Tensor cov3d, - torch::Tensor radii, - torch::Tensor conics, - torch::Tensor v_xy, - torch::Tensor v_conic + torch::Tensor &cov3d, + torch::Tensor &radii, + torch::Tensor &conics, + torch::Tensor &v_xy, + torch::Tensor &v_conic ) { dim3 img_size_dim3; img_size_dim3.x = img_width; @@ -251,4 +253,3 @@ project_gaussians_backward_tensor( return std::make_tuple(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat); } - diff --git a/diff_rast/cuda/csrc/bindings.h b/diff_rast/cuda/csrc/bindings.h index 090c913de..0b392fdff 100644 --- a/diff_rast/cuda/csrc/bindings.h +++ b/diff_rast/cuda/csrc/bindings.h @@ -16,20 +16,20 @@ std::tuple< torch::Tensor, // output conics torch::Tensor> // output radii -compute_cov2d_bounds_forward_tensor(const int num_pts, torch::Tensor A); +compute_cov2d_bounds_forward_tensor(const int num_pts, torch::Tensor &A); torch::Tensor compute_sh_forward_tensor( unsigned num_points, unsigned degree, - torch::Tensor viewdirs, - torch::Tensor coeffs + torch::Tensor &viewdirs, + torch::Tensor &coeffs ); torch::Tensor compute_sh_backward_tensor( unsigned num_points, unsigned degree, - torch::Tensor viewdirs, - torch::Tensor v_colors + torch::Tensor &viewdirs, + torch::Tensor &v_colors ); std::tuple< @@ -41,17 +41,18 @@ std::tuple< torch::Tensor> project_gaussians_forward_tensor( const int num_points, - torch::Tensor means3d, - torch::Tensor scales, + torch::Tensor &means3d, + torch::Tensor &scales, const float glob_scale, - torch::Tensor quats, - torch::Tensor viewmat, - torch::Tensor projmat, + torch::Tensor &quats, + torch::Tensor &viewmat, + torch::Tensor &projmat, const float fx, const float fy, const unsigned img_height, const unsigned img_width, - const std::tuple tile_bounds + const std::tuple tile_bounds, + const float clip_thresh ); std::tuple< @@ -62,19 +63,19 @@ std::tuple< torch::Tensor> project_gaussians_backward_tensor( const int num_points, - torch::Tensor means3d, - torch::Tensor scales, + torch::Tensor &means3d, + torch::Tensor &scales, const float glob_scale, - torch::Tensor quats, - torch::Tensor viewmat, - torch::Tensor projmat, + torch::Tensor &quats, + torch::Tensor &viewmat, + torch::Tensor &projmat, const float fx, const float fy, const unsigned img_height, const unsigned img_width, - torch::Tensor cov3d, - torch::Tensor radii, - torch::Tensor conics, - torch::Tensor v_xy, - torch::Tensor v_conic + torch::Tensor &cov3d, + torch::Tensor &radii, + torch::Tensor &conics, + torch::Tensor &v_xy, + torch::Tensor &v_conic ); diff --git a/diff_rast/cuda/csrc/forward.cu b/diff_rast/cuda/csrc/forward.cu index aebf60816..aa12b771b 100644 --- a/diff_rast/cuda/csrc/forward.cu +++ b/diff_rast/cuda/csrc/forward.cu @@ -1,11 +1,11 @@ #include "forward.cuh" #include "helpers.cuh" +#include #include #include #include #include #include -#include namespace cg = cooperative_groups; @@ -23,6 +23,7 @@ __global__ void project_gaussians_forward_kernel( const float fy, const dim3 img_size, const dim3 tile_bounds, + const float clip_thresh, float *covs3d, float2 *xys, float *depths, @@ -41,7 +42,7 @@ __global__ void project_gaussians_forward_kernel( // printf("p_world %d %.2f %.2f %.2f\n", idx, p_world.x, p_world.y, // p_world.z); float3 p_view; - if (clip_near_plane(p_world, viewmat, p_view)) { + if (clip_near_plane(p_world, viewmat, p_view, clip_thresh)) { // printf("%d is out of frustum z %.2f, returning\n", idx, p_view.z); return; } @@ -59,7 +60,9 @@ __global__ void project_gaussians_forward_kernel( // project to 2d with ewa approximation float tan_fovx = 0.5 * img_size.x / fx; float tan_fovy = 0.5 * img_size.y / fy; - float3 cov2d = project_cov3d_ewa(p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy); + float3 cov2d = project_cov3d_ewa( + p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy + ); // printf("cov2d %d, %.2f %.2f %.2f\n", idx, cov2d.x, cov2d.y, cov2d.z); float3 conic; @@ -104,6 +107,7 @@ void project_gaussians_forward_impl( const float fy, const dim3 img_size, const dim3 tile_bounds, + const float clip_thresh, float *covs3d, float2 *xys, float *depths, @@ -125,6 +129,7 @@ void project_gaussians_forward_impl( fy, img_size, tile_bounds, + clip_thresh, covs3d, xys, depths, @@ -167,7 +172,7 @@ __global__ void map_gaussian_to_intersects( // isect_id is tile ID and depth as int32 int64_t tile_id = i * tile_bounds.x + j; // tile within image isect_ids[cur_idx] = (tile_id << 32) | depth_id; // tile | depth id - gaussian_ids[cur_idx] = idx; // 3D gaussian id + gaussian_ids[cur_idx] = idx; // 3D gaussian id ++cur_idx; // handles gaussians that hit more than one tile } } @@ -209,7 +214,8 @@ void compute_cumulative_intersects( int32_t &num_intersects, int32_t *cum_tiles_hit ) { - // ref: https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a + // ref: + // https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a // allocate sum workspace void *sum_ws = nullptr; size_t sum_ws_bytes; @@ -269,7 +275,7 @@ void bin_and_sort_gaussians( ); // sort intersections by ascending tile ID and depth with RadixSort - int32_t max_tile_id = (int32_t) (tile_bounds.x * tile_bounds.y); + int32_t max_tile_id = (int32_t)(tile_bounds.x * tile_bounds.y); int msb = 32 - __builtin_clz(max_tile_id) + 1; // allocate workspace memory void *sort_ws = nullptr; @@ -300,9 +306,9 @@ void bin_and_sort_gaussians( cudaFree(sort_ws); // get the start and end indices for the gaussians in each tile - // printf("launching tile binning %d %d\n", - // (num_intersects + N_THREADS - 1) / N_THREADS, - // N_THREADS); + // printf("launching tile binning %d %d\n", + // (num_intersects + N_THREADS - 1) / N_THREADS, + // N_THREADS); get_tile_bin_edges<<< (num_intersects + N_THREADS - 1) / N_THREADS, N_THREADS>>>(num_intersects, isect_ids_sorted, tile_bins); @@ -316,7 +322,7 @@ void bin_and_sort_gaussians( // kernel function for rasterizing each tile // each thread treats a single pixel // each thread group uses the same gaussian data in a tile -template +template __global__ void rasterize_forward_kernel( const dim3 tile_bounds, const dim3 img_size, @@ -344,7 +350,7 @@ __global__ void rasterize_forward_kernel( if (i >= img_size.y || j >= img_size.x) { return; } - + // which gaussians to look through in this tile int2 range = tile_bins[tile_id]; float3 conic; @@ -352,7 +358,8 @@ __global__ void rasterize_forward_kernel( float sigma, opac, alpha, vis, next_T; float T = 1.f; - // iterate over all gaussians and apply rendering EWA equation (e.q. 2 from paper) + // iterate over all gaussians and apply rendering EWA equation (e.q. 2 from + // paper) int idx; int32_t g; for (idx = range.x; idx < range.y; ++idx) { @@ -361,8 +368,9 @@ __global__ void rasterize_forward_kernel( center = xys[g]; delta = {center.x - px, center.y - py}; - // Mahalanobis distance (here referred to as sigma) measures how many standard deviations away distance delta is. - // sigma = -0.5(d.T * conic * d) + // Mahalanobis distance (here referred to as sigma) measures how many + // standard deviations away distance delta is. sigma = -0.5(d.T * conic + // * d) sigma = 0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) + conic.y * delta.x * delta.y; @@ -379,8 +387,9 @@ __global__ void rasterize_forward_kernel( } next_T = T * (1.f - alpha); if (next_T <= 1e-4f) { - // we want to render the last gaussian that contributes and note that here idx > range.x so we don't underflow - idx -= 1; + // we want to render the last gaussian that contributes and note + // that here idx > range.x so we don't underflow + idx -= 1; break; } vis = alpha * T; @@ -392,7 +401,7 @@ __global__ void rasterize_forward_kernel( final_Ts[pix_id] = T; // transmittance at last gaussian in this pixel final_index[pix_id] = idx; // index of in bin of last gaussian in this pixel for (int c = 0; c < CHANNELS; ++c) { - out_img[CHANNELS * pix_id + c] += T * background[c]; + out_img[CHANNELS * pix_id + c] += T * background[c]; } } @@ -410,9 +419,9 @@ void rasterize_forward_impl( float *final_Ts, int *final_index, float *out_img, - const float* background + const float *background ) { - rasterize_forward_kernel<3> <<>>( + rasterize_forward_kernel<3><<>>( tile_bounds, img_size, gaussian_ids_sorted, diff --git a/diff_rast/cuda/csrc/forward.cuh b/diff_rast/cuda/csrc/forward.cuh index ce2364235..34ab8f84d 100644 --- a/diff_rast/cuda/csrc/forward.cuh +++ b/diff_rast/cuda/csrc/forward.cuh @@ -14,6 +14,7 @@ void project_gaussians_forward_impl( const float fy, const dim3 img_size, const dim3 tile_bounds, + const float clip_thresh, float *covs3d, float2 *xys, float *depths, diff --git a/diff_rast/cuda/csrc/helpers.cuh b/diff_rast/cuda/csrc/helpers.cuh index 3b61b8625..884857d94 100644 --- a/diff_rast/cuda/csrc/helpers.cuh +++ b/diff_rast/cuda/csrc/helpers.cuh @@ -208,9 +208,9 @@ scale_to_mat(const float3 scale, const float glob_scale) { // device helper for culling near points inline __host__ __device__ bool -clip_near_plane(const float3 p, const float *viewmat, float3 &p_view) { +clip_near_plane(const float3 p, const float *viewmat, float3 &p_view, float thresh) { p_view = transform_4x3(viewmat, p); - if (p_view.z <= 0.01f) { + if (p_view.z <= thresh) { return true; } return false; diff --git a/diff_rast/project_gaussians.py b/diff_rast/project_gaussians.py index 88ed6c327..13ad4990e 100644 --- a/diff_rast/project_gaussians.py +++ b/diff_rast/project_gaussians.py @@ -40,6 +40,7 @@ def forward( img_height: int, img_width: int, tile_bounds: Tuple[int, int, int], + clip_thresh:float=0.01 ): num_points = means3d.shape[-2] @@ -63,6 +64,7 @@ def forward( img_height, img_width, tile_bounds, + clip_thresh, ) # Save non-tensors. @@ -100,7 +102,13 @@ def backward(ctx, v_xys, v_depths, v_radii, v_conics, v_num_tiles_hit, v_cov3d): conics, ) = ctx.saved_tensors - (v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat,) = _C.project_gaussians_backward( + ( + v_cov2d, + v_cov3d, + v_mean3d, + v_scale, + v_quat, + ) = _C.project_gaussians_backward( ctx.num_points, means3d, scales,