Skip to content

Commit

Permalink
adding adjustable clip plane, making bindings take tensor& not tensor…
Browse files Browse the repository at this point in the history
… for efficiency
  • Loading branch information
vye16 committed Sep 29, 2023
1 parent 9e64677 commit 35b0432
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 72 deletions.
59 changes: 30 additions & 29 deletions diff_rast/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
namespace cg = cooperative_groups;

__global__ void compute_cov2d_bounds_forward_kernel(
const unsigned num_pts,
const float *covs2d,
float *conics,
float *radii
const unsigned num_pts, const float *covs2d, float *conics, float *radii
) {
unsigned row = cg::this_grid().thread_rank();
if (row >= num_pts) {
Expand All @@ -31,7 +28,9 @@ __global__ void compute_cov2d_bounds_forward_kernel(
int index = row * 3;
float3 conic;
float radius;
float3 cov2d{(float)covs2d[index], (float)covs2d[index + 1], (float)covs2d[index + 2]};
float3 cov2d{
(float)covs2d[index], (float)covs2d[index + 1], (float)covs2d[index + 2]
};
compute_cov2d_bounds(cov2d, conic, radius);
conics[index] = conic.x;
conics[index + 1] = conic.y;
Expand All @@ -42,10 +41,11 @@ __global__ void compute_cov2d_bounds_forward_kernel(
std::tuple<
torch::Tensor, // output conics
torch::Tensor> // output radii
compute_cov2d_bounds_forward_tensor(const int num_pts, torch::Tensor covs2d) {
compute_cov2d_bounds_forward_tensor(const int num_pts, torch::Tensor &covs2d) {
CHECK_INPUT(covs2d);
torch::Tensor conics =
torch::zeros({num_pts, covs2d.size(1)}, covs2d.options().dtype(torch::kFloat32));
torch::Tensor conics = torch::zeros(
{num_pts, covs2d.size(1)}, covs2d.options().dtype(torch::kFloat32)
);
torch::Tensor radii =
torch::zeros({num_pts, 1}, covs2d.options().dtype(torch::kFloat32));

Expand All @@ -63,8 +63,8 @@ compute_cov2d_bounds_forward_tensor(const int num_pts, torch::Tensor covs2d) {
torch::Tensor compute_sh_forward_tensor(
const unsigned num_points,
const unsigned degree,
torch::Tensor viewdirs,
torch::Tensor coeffs
torch::Tensor &viewdirs,
torch::Tensor &coeffs
) {
unsigned num_bases = num_sh_bases(degree);
if (coeffs.ndimension() != 3 || coeffs.size(0) != num_points ||
Expand All @@ -87,8 +87,8 @@ torch::Tensor compute_sh_forward_tensor(
torch::Tensor compute_sh_backward_tensor(
const unsigned num_points,
const unsigned degree,
torch::Tensor viewdirs,
torch::Tensor v_colors
torch::Tensor &viewdirs,
torch::Tensor &v_colors
) {
if (viewdirs.ndimension() != 2 || viewdirs.size(0) != num_points ||
viewdirs.size(1) != 3) {
Expand Down Expand Up @@ -122,17 +122,18 @@ std::tuple<
torch::Tensor>
project_gaussians_forward_tensor(
const int num_points,
torch::Tensor means3d,
torch::Tensor scales,
torch::Tensor &means3d,
torch::Tensor &scales,
const float glob_scale,
torch::Tensor quats,
torch::Tensor viewmat,
torch::Tensor projmat,
torch::Tensor &quats,
torch::Tensor &viewmat,
torch::Tensor &projmat,
const float fx,
const float fy,
const unsigned img_height,
const unsigned img_width,
const std::tuple<int, int, int> tile_bounds
const std::tuple<int, int, int> tile_bounds,
const float clip_thresh
) {
dim3 img_size_dim3;
img_size_dim3.x = img_width;
Expand Down Expand Up @@ -169,6 +170,7 @@ project_gaussians_forward_tensor(
fy,
img_size_dim3,
tile_bounds_dim3,
clip_thresh,
// Outputs.
cov3d_d.contiguous().data_ptr<float>(),
(float2 *)xys_d.contiguous().data_ptr<float>(),
Expand All @@ -191,21 +193,21 @@ std::tuple<
torch::Tensor>
project_gaussians_backward_tensor(
const int num_points,
torch::Tensor means3d,
torch::Tensor scales,
torch::Tensor &means3d,
torch::Tensor &scales,
const float glob_scale,
torch::Tensor quats,
torch::Tensor viewmat,
torch::Tensor projmat,
torch::Tensor &quats,
torch::Tensor &viewmat,
torch::Tensor &projmat,
const float fx,
const float fy,
const unsigned img_height,
const unsigned img_width,
torch::Tensor cov3d,
torch::Tensor radii,
torch::Tensor conics,
torch::Tensor v_xy,
torch::Tensor v_conic
torch::Tensor &cov3d,
torch::Tensor &radii,
torch::Tensor &conics,
torch::Tensor &v_xy,
torch::Tensor &v_conic
) {
dim3 img_size_dim3;
img_size_dim3.x = img_width;
Expand Down Expand Up @@ -251,4 +253,3 @@ project_gaussians_backward_tensor(

return std::make_tuple(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat);
}

43 changes: 22 additions & 21 deletions diff_rast/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@
std::tuple<
torch::Tensor, // output conics
torch::Tensor> // output radii
compute_cov2d_bounds_forward_tensor(const int num_pts, torch::Tensor A);
compute_cov2d_bounds_forward_tensor(const int num_pts, torch::Tensor &A);

torch::Tensor compute_sh_forward_tensor(
unsigned num_points,
unsigned degree,
torch::Tensor viewdirs,
torch::Tensor coeffs
torch::Tensor &viewdirs,
torch::Tensor &coeffs
);

torch::Tensor compute_sh_backward_tensor(
unsigned num_points,
unsigned degree,
torch::Tensor viewdirs,
torch::Tensor v_colors
torch::Tensor &viewdirs,
torch::Tensor &v_colors
);

std::tuple<
Expand All @@ -41,17 +41,18 @@ std::tuple<
torch::Tensor>
project_gaussians_forward_tensor(
const int num_points,
torch::Tensor means3d,
torch::Tensor scales,
torch::Tensor &means3d,
torch::Tensor &scales,
const float glob_scale,
torch::Tensor quats,
torch::Tensor viewmat,
torch::Tensor projmat,
torch::Tensor &quats,
torch::Tensor &viewmat,
torch::Tensor &projmat,
const float fx,
const float fy,
const unsigned img_height,
const unsigned img_width,
const std::tuple<int, int, int> tile_bounds
const std::tuple<int, int, int> tile_bounds,
const float clip_thresh
);

std::tuple<
Expand All @@ -62,19 +63,19 @@ std::tuple<
torch::Tensor>
project_gaussians_backward_tensor(
const int num_points,
torch::Tensor means3d,
torch::Tensor scales,
torch::Tensor &means3d,
torch::Tensor &scales,
const float glob_scale,
torch::Tensor quats,
torch::Tensor viewmat,
torch::Tensor projmat,
torch::Tensor &quats,
torch::Tensor &viewmat,
torch::Tensor &projmat,
const float fx,
const float fy,
const unsigned img_height,
const unsigned img_width,
torch::Tensor cov3d,
torch::Tensor radii,
torch::Tensor conics,
torch::Tensor v_xy,
torch::Tensor v_conic
torch::Tensor &cov3d,
torch::Tensor &radii,
torch::Tensor &conics,
torch::Tensor &v_xy,
torch::Tensor &v_conic
);
47 changes: 28 additions & 19 deletions diff_rast/cuda/csrc/forward.cu
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#include "forward.cuh"
#include "helpers.cuh"
#include <algorithm>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <iostream>
#include <algorithm>

namespace cg = cooperative_groups;

Expand All @@ -23,6 +23,7 @@ __global__ void project_gaussians_forward_kernel(
const float fy,
const dim3 img_size,
const dim3 tile_bounds,
const float clip_thresh,
float *covs3d,
float2 *xys,
float *depths,
Expand All @@ -41,7 +42,7 @@ __global__ void project_gaussians_forward_kernel(
// printf("p_world %d %.2f %.2f %.2f\n", idx, p_world.x, p_world.y,
// p_world.z);
float3 p_view;
if (clip_near_plane(p_world, viewmat, p_view)) {
if (clip_near_plane(p_world, viewmat, p_view, clip_thresh)) {
// printf("%d is out of frustum z %.2f, returning\n", idx, p_view.z);
return;
}
Expand All @@ -59,7 +60,9 @@ __global__ void project_gaussians_forward_kernel(
// project to 2d with ewa approximation
float tan_fovx = 0.5 * img_size.x / fx;
float tan_fovy = 0.5 * img_size.y / fy;
float3 cov2d = project_cov3d_ewa(p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy);
float3 cov2d = project_cov3d_ewa(
p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy
);
// printf("cov2d %d, %.2f %.2f %.2f\n", idx, cov2d.x, cov2d.y, cov2d.z);

float3 conic;
Expand Down Expand Up @@ -104,6 +107,7 @@ void project_gaussians_forward_impl(
const float fy,
const dim3 img_size,
const dim3 tile_bounds,
const float clip_thresh,
float *covs3d,
float2 *xys,
float *depths,
Expand All @@ -125,6 +129,7 @@ void project_gaussians_forward_impl(
fy,
img_size,
tile_bounds,
clip_thresh,
covs3d,
xys,
depths,
Expand Down Expand Up @@ -167,7 +172,7 @@ __global__ void map_gaussian_to_intersects(
// isect_id is tile ID and depth as int32
int64_t tile_id = i * tile_bounds.x + j; // tile within image
isect_ids[cur_idx] = (tile_id << 32) | depth_id; // tile | depth id
gaussian_ids[cur_idx] = idx; // 3D gaussian id
gaussian_ids[cur_idx] = idx; // 3D gaussian id
++cur_idx; // handles gaussians that hit more than one tile
}
}
Expand Down Expand Up @@ -209,7 +214,8 @@ void compute_cumulative_intersects(
int32_t &num_intersects,
int32_t *cum_tiles_hit
) {
// ref: https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a
// ref:
// https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a
// allocate sum workspace
void *sum_ws = nullptr;
size_t sum_ws_bytes;
Expand Down Expand Up @@ -269,7 +275,7 @@ void bin_and_sort_gaussians(
);

// sort intersections by ascending tile ID and depth with RadixSort
int32_t max_tile_id = (int32_t) (tile_bounds.x * tile_bounds.y);
int32_t max_tile_id = (int32_t)(tile_bounds.x * tile_bounds.y);
int msb = 32 - __builtin_clz(max_tile_id) + 1;
// allocate workspace memory
void *sort_ws = nullptr;
Expand Down Expand Up @@ -300,9 +306,9 @@ void bin_and_sort_gaussians(
cudaFree(sort_ws);

// get the start and end indices for the gaussians in each tile
// printf("launching tile binning %d %d\n",
// (num_intersects + N_THREADS - 1) / N_THREADS,
// N_THREADS);
// printf("launching tile binning %d %d\n",
// (num_intersects + N_THREADS - 1) / N_THREADS,
// N_THREADS);
get_tile_bin_edges<<<
(num_intersects + N_THREADS - 1) / N_THREADS,
N_THREADS>>>(num_intersects, isect_ids_sorted, tile_bins);
Expand All @@ -316,7 +322,7 @@ void bin_and_sort_gaussians(
// kernel function for rasterizing each tile
// each thread treats a single pixel
// each thread group uses the same gaussian data in a tile
template<int CHANNELS>
template <int CHANNELS>
__global__ void rasterize_forward_kernel(
const dim3 tile_bounds,
const dim3 img_size,
Expand Down Expand Up @@ -344,15 +350,16 @@ __global__ void rasterize_forward_kernel(
if (i >= img_size.y || j >= img_size.x) {
return;
}

// which gaussians to look through in this tile
int2 range = tile_bins[tile_id];
float3 conic;
float2 center, delta;
float sigma, opac, alpha, vis, next_T;
float T = 1.f;

// iterate over all gaussians and apply the EWA rendering equation (eq. 2 from the paper)
// iterate over all gaussians and apply the EWA rendering equation (eq. 2
// from the paper)
int idx;
int32_t g;
for (idx = range.x; idx < range.y; ++idx) {
Expand All @@ -361,8 +368,9 @@ __global__ void rasterize_forward_kernel(
center = xys[g];
delta = {center.x - px, center.y - py};

// Mahalanobis distance measures how many standard deviations away distance delta is; sigma here is half its square.
// sigma = 0.5 * (d.T * conic * d)
// Mahalanobis distance measures how many standard deviations away
// distance delta is; sigma here is half its square.
// sigma = 0.5 * (d.T * conic * d)
sigma =
0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) +
conic.y * delta.x * delta.y;
Expand All @@ -379,8 +387,9 @@ __global__ void rasterize_forward_kernel(
}
next_T = T * (1.f - alpha);
if (next_T <= 1e-4f) {
// we want to render the last gaussian that contributes and note that here idx > range.x so we don't underflow
idx -= 1;
// we want to render the last gaussian that contributes and note
// that here idx > range.x so we don't underflow
idx -= 1;
break;
}
vis = alpha * T;
Expand All @@ -392,7 +401,7 @@ __global__ void rasterize_forward_kernel(
final_Ts[pix_id] = T; // transmittance at last gaussian in this pixel
final_index[pix_id] = idx; // index of in bin of last gaussian in this pixel
for (int c = 0; c < CHANNELS; ++c) {
out_img[CHANNELS * pix_id + c] += T * background[c];
out_img[CHANNELS * pix_id + c] += T * background[c];
}
}

Expand All @@ -410,9 +419,9 @@ void rasterize_forward_impl(
float *final_Ts,
int *final_index,
float *out_img,
const float* background
const float *background
) {
rasterize_forward_kernel<3> <<<tile_bounds, block>>>(
rasterize_forward_kernel<3><<<tile_bounds, block>>>(
tile_bounds,
img_size,
gaussian_ids_sorted,
Expand Down
1 change: 1 addition & 0 deletions diff_rast/cuda/csrc/forward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ void project_gaussians_forward_impl(
const float fy,
const dim3 img_size,
const dim3 tile_bounds,
const float clip_thresh,
float *covs3d,
float2 *xys,
float *depths,
Expand Down
Loading

0 comments on commit 35b0432

Please sign in to comment.