From bbc9e98fa9add8b355b91b69096ec3b1271a9f46 Mon Sep 17 00:00:00 2001
From: DylanWaken <109344556+DylanWaken@users.noreply.github.com>
Date: Sun, 29 Sep 2024 20:48:34 -0700
Subject: [PATCH] A set of comments and annotations to 2DGS CUDA kernels (#429)

* annotations for cuda kernels

* remove questions
---
 .../csrc/fully_fused_projection_2dgs_fwd.cu   | 150 +++++++--
 .../cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu | 285 ++++++++++++++----
 .../cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu | 182 ++++++++---
 3 files changed, 488 insertions(+), 129 deletions(-)

diff --git a/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu
index 08726d0f5..d9beedc36 100644
--- a/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu
+++ b/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu
@@ -20,37 +20,61 @@ template <typename T>
 __global__ void fully_fused_projection_fwd_2dgs_kernel(
     const uint32_t C,
     const uint32_t N,
-    const T *__restrict__ means,    // [N, 3]
-    const T *__restrict__ quats,    // [N, 4]
-    const T *__restrict__ scales,   // [N, 3]
-    const T *__restrict__ viewmats, // [C, 4, 4]
-    const T *__restrict__ Ks,       // [C, 3, 3]
-    const int32_t image_width,
-    const int32_t image_height,
-    const T near_plane,
-    const T far_plane,
-    const T radius_clip,
+    const T *__restrict__ means,    // [N, 3]: Gaussian means (i.e. source points)
+    const T *__restrict__ quats,    // [N, 4]: Quaternions (need not be normalized); the rotation component of each primitive
+    const T *__restrict__ scales,   // [N, 3]: Scales for x, y, z
+    const T *__restrict__ viewmats, // [C, 4, 4]: World-to-camera view matrices
+                                    // [R t]
+                                    // [0 1]
+    const T *__restrict__ Ks,       // [C, 3, 3]: Projective transformation matrix
+                                    // [f_x 0   c_x]
+                                    // [0   f_y c_y]
+                                    // [0   0   1  ] : f_x, f_y are the focal lengths, (c_x, c_y) is the principal point in screen space
+    const int32_t image_width,      // Image width in pixels
+    const int32_t image_height,     // Image height in pixels
+    const T near_plane,             // Near clipping plane (for the finite range used in z sorting)
+    const T far_plane,              // Far clipping plane (for the finite range used in z sorting)
+    const T radius_clip,            // Radius clipping threshold (throw away small primitives)
     // outputs
-    int32_t *__restrict__ radii,  // [C, N]
-    T *__restrict__ means2d,      // [C, N, 2]
-    T *__restrict__ depths,       // [C, N]
-    T *__restrict__ ray_transforms,   // [C, N, 3, 3]
-    T *__restrict__ normals       // [C, N, 3]
+    int32_t *__restrict__ radii,    // [C, N] The maximum radius of the projected Gaussians in pixel units. Int32 tensor of shape [C, N].
+    T *__restrict__ means2d,        // [C, N, 2] 2D means of the projected Gaussians.
+    T *__restrict__ depths,         // [C, N] The z-depth of the projected Gaussians.
+    T *__restrict__ ray_transforms, // [C, N, 3, 3] Transformation matrices that transform xy-planes in pixel space into splat coordinates; (WH)^T in equation (9) of the paper.
+    T *__restrict__ normals         // [C, N, 3] The normals in camera space.
 ) {
+
+    /**
+     * ===============================================
+     * Initialize execution and threading variables:
+     *   idx: global thread index
+     *   cid: camera id
+     *   gid: gaussian id
+     * (C is the number of cameras, N is the total number of primitives)
+     *
+     * THIS KERNEL LAUNCHES PER PRIMITIVE PER CAMERA, i.e. C*N THREADS IN TOTAL
+     * ===============================================
+     */
+
     // parallelize over C * N.
-    uint32_t idx = cg::this_grid().thread_rank();
+    uint32_t idx = cg::this_grid().thread_rank(); // get the thread index from the grid
     if (idx >= C * N) {
         return;
     }
     const uint32_t cid = idx / N; // camera id
     const uint32_t gid = idx % N; // gaussian id

+    /**
+     * ===============================================
+     * Load data and put together camera rotation / translation
+     * ===============================================
+     */
+
     // shift pointers to the current camera and gaussian
-    means += gid * 3;
-    viewmats += cid * 16;
-    Ks += cid * 9;
+    means += gid * 3;     // point to the mean of the primitive this thread is responsible for
+    viewmats += cid * 16; // advance to this camera's 4x4 view matrix
+    Ks += cid * 9;        // advance to this camera's 3x3 intrinsic matrix

     // glm is column-major but input is row-major
+    // rotation component of the camera (explicit transpose)
     mat3 R = mat3(
         viewmats[0],
         viewmats[4],
@@ -62,39 +86,103 @@ __global__ void fully_fused_projection_fwd_2dgs_kernel(
         viewmats[6],
         viewmats[10] // 3rd column
     );
+    // translation component of the camera
     vec3 t = vec3(viewmats[3], viewmats[7], viewmats[11]);

+    /**
+     * ===============================================
+     * Build the ray transformation matrix from primitive (UV) space to camera space.
+     * In the original paper, q_ray = [xz, yz, z, 1] = WH * q_uv, with q_uv = [u, v, 1, 1].
+     *
+     * Thus: RS_camera = R * H(P->W)
+     *
+     * Since the H matrix (4x4) is defined as:
+     *   [v_x v_y 0_vec3 t]
+     *   [0   0   0      1]
+     *
+     * RS_camera is R * [v_x v_y 0], which gives
+     *   [R⋅v_x  R⋅v_y  0]
+     * so the only non-zero columns are the first two (R⋅v_x and R⋅v_y).
+     *
+     * This gives the "affine rotation component" from uv to camera space as RS_camera.
+     *
+     * The final translation component is mean_c, the center of the primitive in camera space:
+     *   q_cam = RS_camera * q_uv + mean_c
+     *
+     * As with homogeneous coordinates, if we encode incoming 2D points as [u, v, 1], we have:
+     *   q_cam = [RS_camera[0,1] | mean_c] * [u, v, 1]
+     * ===============================================
+     */
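
To make the map above concrete, here is a small standalone host-side sketch (illustrative only — the file name and plain-array helpers are ours, not gsplat code) that checks numerically that (K * WH) applied to [u, v, 1] produces the homogeneous pixel coordinates [x*z, y*z, z], which is exactly the property the ray transform M built below relies on:

// standalone_projection_check.cpp -- illustrative sketch only, not part of gsplat.
// Checks that with WH = [RS_camera[0] | RS_camera[1] | mean_c] and intrinsics K,
// (K * WH) maps [u, v, 1] to the homogeneous pixel coordinates [x*z, y*z, z].
#include <cmath>
#include <cstdio>

// multiply a 3x3 row-major matrix by a 3-vector
static void matvec3(const double m[3][3], const double v[3], double out[3]) {
    for (int i = 0; i < 3; ++i)
        out[i] = m[i][0] * v[0] + m[i][1] * v[1] + m[i][2] * v[2];
}

int main() {
    // arbitrary splat frame: two tangent columns of RS_camera and the camera-space center
    const double v0[3] = {0.30, -0.10, 0.05};
    const double v1[3] = {0.02, 0.25, 0.10};
    const double mean_c[3] = {0.40, -0.30, 2.50};
    // arbitrary intrinsics
    const double fx = 500.0, fy = 480.0, cx = 320.0, cy = 240.0;

    // WH = [v0 | v1 | mean_c] (columns), written here in row-major form
    const double WH[3][3] = {
        {v0[0], v1[0], mean_c[0]},
        {v0[1], v1[1], mean_c[1]},
        {v0[2], v1[2], mean_c[2]},
    };
    const double K[3][3] = {{fx, 0, cx}, {0, fy, cy}, {0, 0, 1}};

    const double q_uv[3] = {0.7, -1.3, 1.0}; // a point on the splat, [u, v, 1]
    double q_cam[3], KWH[3][3] = {}, q_ray[3];

    // q_cam = WH * [u, v, 1]  (the affine map described in the comment above)
    matvec3(WH, q_uv, q_cam);

    // KWH = K * WH
    for (int i = 0; i < 3; ++i)
        for (int j = 0; j < 3; ++j)
            for (int k = 0; k < 3; ++k)
                KWH[i][j] += K[i][k] * WH[k][j];

    // q_ray = (K * WH) * [u, v, 1]  -- should equal [x*z, y*z, z]
    matvec3(KWH, q_uv, q_ray);

    const double x_pix = fx * q_cam[0] / q_cam[2] + cx;
    const double y_pix = fy * q_cam[1] / q_cam[2] + cy;
    std::printf("from q_ray: (%.6f, %.6f)  direct projection: (%.6f, %.6f)\n",
                q_ray[0] / q_ray[2], q_ray[1] / q_ray[2], x_pix, y_pix);
    return 0;
}
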
+    // transform the Gaussian center to camera space
     vec3 mean_c;
     pos_world_to_cam(R, t, glm::make_vec3(means), mean_c);
+
+    // cull primitives outside the near/far clipping planes
     if (mean_c.z < near_plane || mean_c.z > far_plane) {
         radii[idx] = 0;
         return;
     }

-    // build ray transformation matrix and transform from world space to camera
-    // space
     quats += gid * 4;
     scales += gid * 3;

     mat3 RS_camera =
         R * quat_to_rotmat(glm::make_vec4(quats)) *
-        mat3(scales[0], 0.0, 0.0, 0.0, scales[1], 0.0, 0.0, 0.0, 1.0);
+        mat3(scales[0], 0.0      , 0.0,
+             0.0      , scales[1], 0.0,
+             0.0      , 0.0      , 1.0);

     mat3 WH = mat3(RS_camera[0], RS_camera[1], mean_c);

+    // projective transformation matrix: camera -> screen
+    // written in this order, the matrix is actually K^T, since glm reads it in column-major order:
+    // [Ks[0], 0,     0]
+    // [0,     Ks[4], 0]
+    // [Ks[2], Ks[5], 1]
     mat3 world_2_pix =
-        mat3(Ks[0], 0.0, Ks[2], 0.0, Ks[4], Ks[5], 0.0, 0.0, 1.0);
+        mat3(Ks[0], 0.0  , Ks[2],
+             0.0  , Ks[4], Ks[5],
+             0.0  , 0.0  , 1.0);
+
+    // WH is defined as [R⋅v_x, R⋅v_y, mean_c]: q_uv = [u,v,-1] -> q_cam = [c1,c2,c3]
+    // note that world_2_pix as written above is actually K^T
+    // M is thus (KWH)^T = (WH)^T * K^T = (WH)^T * world_2_pix
+    // so M stores the row-major version of KWH, i.e. the column-major version of (KWH)^T
     mat3 M = glm::transpose(WH) * world_2_pix;

+    /**
+     * ===============================================
+     * Compute AABB
+     * ===============================================
+     */
     // compute AABB
-    const vec3 M0 = vec3(M[0][0], M[0][1], M[0][2]);
-    const vec3 M1 = vec3(M[1][0], M[1][1], M[1][2]);
-    const vec3 M2 = vec3(M[2][0], M[2][1], M[2][2]);
+    const vec3 M0 = vec3(M[0][0], M[0][1], M[0][2]); // the first column of (KWH)^T, thus the first row of KWH
+    const vec3 M1 = vec3(M[1][0], M[1][1], M[1][2]); // the second column of (KWH)^T, thus the second row of KWH
+    const vec3 M2 = vec3(M[2][0], M[2][1], M[2][2]); // the third column of (KWH)^T, thus the third row of KWH

+    // we know that KWH brings [u,v,-1] to [ray1, ray2, ray3] = [xz, yz, z]
+    // temp_point is [1,1,-1], which is a "corner" of the UV space.
const vec3 temp_point = vec3(1.0f, 1.0f, -1.0f); + + // ============================================== + // trivial implementation to find mean and 1 sigma radius + // ============================================== + // const vec3 mean_ray = glm::transpose(M) * vec3(0.0f, 0.0f, -1.0f); + // const vec3 temp_point_ray = glm::transpose(M) * temp_point; + + // const vec2 mean2d = vec2(mean_ray.x / mean_ray.z, mean_ray.y / mean_ray.z); + // const vec2 half_extend_p = vec2(temp_point_ray.x / temp_point_ray.z, temp_point_ray.y / temp_point_ray.z) - mean2d; + // const vec2 half_extend = vec2(half_extend_p.x * half_extend_p.x, half_extend_p.y * half_extend_p.y); + + // ============================================== + // pro implementation + // ============================================== + // this is purely resulted from algebraic manipulation + // check here for details: https://github.com/hbb1/diff-surfel-rasterization/issues/8#issuecomment-2138069016 const T distance = sum(temp_point * M2 * M2); + // ill-conditioned primitives will have distance = 0.0f, we ignore them if (distance == 0.0f) return; @@ -103,6 +191,8 @@ __global__ void fully_fused_projection_fwd_2dgs_kernel( const vec2 temp = {sum(f * M0 * M0), sum(f * M1 * M1)}; const vec2 half_extend = mean2d * mean2d - temp; + + // ============================================== const T radius = ceil(3.f * sqrt(max(1e-4, max(half_extend.x, half_extend.y)))); @@ -111,15 +201,17 @@ __global__ void fully_fused_projection_fwd_2dgs_kernel( return; } + // CULLING STEP: // mask out gaussians outside the image region if (mean2d.x + radius <= 0 || mean2d.x - radius >= image_width || mean2d.y + radius <= 0 || mean2d.y - radius >= image_height) { - radii[idx] = 0; + radii[idx] = 0; return; } // normals dual visible vec3 normal = RS_camera[2]; + // flip normal if it is pointing away from the camera T multipler = glm::dot(-normal, mean_c) > 0 ? 1 : -1; normal *= multipler; @@ -128,6 +220,8 @@ __global__ void fully_fused_projection_fwd_2dgs_kernel( means2d[idx * 2] = mean2d.x; means2d[idx * 2 + 1] = mean2d.y; depths[idx] = mean_c.z; + + // row major storing (KWH) ray_transforms[idx * 9] = M0.x; ray_transforms[idx * 9 + 1] = M0.y; ray_transforms[idx * 9 + 2] = M0.z; @@ -137,6 +231,8 @@ __global__ void fully_fused_projection_fwd_2dgs_kernel( ray_transforms[idx * 9 + 6] = M2.x; ray_transforms[idx * 9 + 7] = M2.y; ray_transforms[idx * 9 + 8] = M2.z; + + // primitive normals normals[idx * 3] = normal.x; normals[idx * 3 + 1] = normal.y; normals[idx * 3 + 2] = normal.z; diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu index ddf508283..fd77480bd 100644 --- a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu +++ b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu @@ -15,18 +15,20 @@ namespace cg = cooperative_groups; ****************************************************************************/ template __global__ void rasterize_to_pixels_bwd_2dgs_kernel( - const uint32_t C, - const uint32_t N, - const uint32_t n_isects, - const bool packed, + const uint32_t C, // number of cameras + const uint32_t N, // number of gaussians + const uint32_t n_isects, // number of ray-primitive intersections. 
+ const bool packed, // whether the input tensors are packed // fwd inputs - const vec2 *__restrict__ means2d, // [C, N, 2] or [nnz, 2] - const S *__restrict__ ray_transforms, // [C, N, 3] or [nnz, 3] - const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] - const S *__restrict__ normals, // [C, N, 3] or [nnz, 3] - const S *__restrict__ opacities, // [C, N] or [nnz] - const S *__restrict__ backgrounds, // [C, COLOR_DIM] or [nnz, COLOR_DIM] - const bool *__restrict__ masks, // [C, tile_height, tile_width] + const vec2 *__restrict__ means2d, // Projected Gaussian means. [C, N, 2] if packed is False, [nnz, 2] if packed is True. + const S *__restrict__ ray_transforms, // transformation matrices that transforms xy-planes in pixel spaces into splat coordinates. [C, N, 3, 3] if packed is False, [nnz, channels] if packed is True. + // This is (KWH)^{-1} in the paper (takes screen [x,y] and map to [u,v]) + const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] // Gaussian colors or ND features. + const S *__restrict__ normals, // [C, N, 3] or [nnz, 3] // The normals in camera space. + const S *__restrict__ opacities, // [C, N] or [nnz] // Gaussian opacities that support per-view values. + const S *__restrict__ backgrounds, // [C, COLOR_DIM] // Background colors on camera basis + const bool *__restrict__ masks, // [C, tile_height, tile_width] // Optional tile mask to skip rendering GS to masked tiles. + const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, @@ -34,19 +36,22 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel( const uint32_t tile_height, const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] const int32_t *__restrict__ flatten_ids, // [n_isects] + // fwd outputs const S *__restrict__ render_colors, // [C, image_height, image_width, // COLOR_DIM] const S *__restrict__ render_alphas, // [C, image_height, image_width, 1] - const int32_t *__restrict__ last_ids, // [C, image_height, image_width] - const int32_t *__restrict__ median_ids, // [C, image_height, image_width] + const int32_t *__restrict__ last_ids, // [C, image_height, image_width] // the id to last gaussian that got intersected + const int32_t *__restrict__ median_ids, // [C, image_height, image_width] // the id to the gaussian that brings the opacity over 0.5 + // grad outputs - const S *__restrict__ v_render_colors, // [C, image_height, image_width, + const S *__restrict__ v_render_colors, // [C, image_height, image_width, // RGB // COLOR_DIM] - const S *__restrict__ v_render_alphas, // [C, image_height, image_width, 1] - const S *__restrict__ v_render_normals, // [C, image_height, image_width, 3] - const S *__restrict__ v_render_distort, // [C, image_height, image_width, 1] - const S *__restrict__ v_render_median, // [C, image_height, image_width, 1] + const S *__restrict__ v_render_alphas, // [C, image_height, image_width, 1] // total opacities. 
+ const S *__restrict__ v_render_normals, // [C, image_height, image_width, 3] // camera space normals + const S *__restrict__ v_render_distort, // [C, image_height, image_width, 1] // mip-nerf 360 distorts + const S *__restrict__ v_render_median, // [C, image_height, image_width, 1] // the median depth + // grad inputs vec2 *__restrict__ v_means2d_abs, // [C, N, 2] or [nnz, 2] vec2 *__restrict__ v_means2d, // [C, N, 2] or [nnz, 2] @@ -56,21 +61,30 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel( S *__restrict__ v_normals, // [C, N, 3] or [nnz, 3] S *__restrict__ v_densify ) { + /** + * ============================== + * Set up the thread blocks + * blocks are assigned tilewise, and threads are assigned pixelwise + * ============================== + */ auto block = cg::this_thread_block(); uint32_t camera_id = block.group_index().x; - uint32_t tile_id = - block.group_index().y * tile_width + block.group_index().z; + uint32_t tile_id = block.group_index().y * tile_width + block.group_index().z; uint32_t i = block.group_index().y * tile_size + block.thread_index().y; uint32_t j = block.group_index().z * tile_size + block.thread_index().x; tile_offsets += camera_id * tile_height * tile_width; render_alphas += camera_id * image_height * image_width; + last_ids += camera_id * image_height * image_width; median_ids += camera_id * image_height * image_width; + v_render_colors += camera_id * image_height * image_width * COLOR_DIM; v_render_alphas += camera_id * image_height * image_width; v_render_normals += camera_id * image_height * image_width * 3; v_render_median += camera_id * image_height * image_width; + + if (backgrounds != nullptr) { backgrounds += camera_id * COLOR_DIM; } @@ -105,50 +119,69 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel( ? 
              n_isects
            : tile_offsets[tile_id + 1];
     const uint32_t block_size = block.size();
-    const uint32_t num_batches =
-        (range_end - range_start + block_size - 1) / block_size;
+    // number of batches needed to process all gaussians in this tile
+    const uint32_t num_batches = (range_end - range_start + block_size - 1) / block_size;
+
+    /**
+     * ==============================
+     * Memory allocation
+     * Shared memory is laid out as:
+     *   | pix_x : pix_y : opac | u_M : v_M : w_M | rgb : normal |
+     * ==============================
+     */
     extern __shared__ int s[];
     int32_t *id_batch = (int32_t *)s; // [block_size]
-    vec3<S> *xy_opacity_batch =
-        reinterpret_cast<vec3<S> *>(&id_batch[block_size]); // [block_size]
-    vec3<S> *u_Ms_batch =
-        reinterpret_cast<vec3<S> *>(&xy_opacity_batch[block_size]
-        ); // [block_size]
-    vec3<S> *v_Ms_batch =
-        reinterpret_cast<vec3<S> *>(&u_Ms_batch[block_size]
-        ); // [block_size]
-    vec3<S> *w_Ms_batch =
-        reinterpret_cast<vec3<S> *>(&v_Ms_batch[block_size]
-        ); // [block_size]
+
+    vec3<S> *xy_opacity_batch = reinterpret_cast<vec3<S> *>(&id_batch[block_size]);   // [block_size]
+    vec3<S> *u_Ms_batch = reinterpret_cast<vec3<S> *>(&xy_opacity_batch[block_size]); // [block_size]
+    vec3<S> *v_Ms_batch = reinterpret_cast<vec3<S> *>(&u_Ms_batch[block_size]);       // [block_size]
+    vec3<S> *w_Ms_batch = reinterpret_cast<vec3<S> *>(&v_Ms_batch[block_size]);       // [block_size]
+
+    // extended memory block
     S *rgbs_batch = (S *)&w_Ms_batch[block_size]; // [block_size * COLOR_DIM]
     S *normals_batch = &rgbs_batch[block_size * COLOR_DIM]; // [block_size * 3]

     // this is the T AFTER the last gaussian in this pixel
     S T_final = 1.0f - render_alphas[pix_id];
     S T = T_final;
+    // accumulated contribution of the gaussians behind the current one;
+    // this is used to compute d(render)/d(alpha_i)
     S buffer[COLOR_DIM] = {0.f};
     S buffer_normals[3] = {0.f};
+    // index of the last gaussian to contribute to this pixel
     const int32_t bin_final = inside ? last_ids[pix_id] : 0;
+    // index of the gaussian that contributes the median depth
     const int32_t median_idx = inside ? median_ids[pix_id] : 0;

-    // df/d_out for this pixel
+    /**
+     * ==============================
+     * Fetching data
+     * ==============================
+     */
+
+    // df/d_out for this pixel (kept in registers)
+    // FETCH COLOR GRADIENT
     S v_render_c[COLOR_DIM];
     GSPLAT_PRAGMA_UNROLL
     for (uint32_t k = 0; k < COLOR_DIM; ++k) {
         v_render_c[k] = v_render_colors[pix_id * COLOR_DIM + k];
     }
+
+    // FETCH ALPHA GRADIENT
     const S v_render_a = v_render_alphas[pix_id];
     S v_render_n[3];
+
+    // FETCH NORMAL GRADIENT
     GSPLAT_PRAGMA_UNROLL
     for (uint32_t k = 0; k < 3; ++k) {
         v_render_n[k] = v_render_normals[pix_id * 3 + k];
     }

-    // prepare for distortion
+    // PREPARE FOR DISTORTION (IF THE DISTORTION LOSS IS ENABLED)
     S v_distort = 0.f;
     S accum_d, accum_w;
     S accum_d_buffer, accum_w_buffer, distort_buffer;
@@ -169,8 +202,17 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel(
     // each thread loads one gaussian at a time before rasterizing
     const uint32_t tr = block.thread_rank();
     cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
-    const int32_t warp_bin_final =
-        cg::reduce(warp, bin_final, cg::greater<int>());
+
+    // find the maximum last-contributing gaussian id within the thread warp;
+    // this gives the last gaussian that has intersected any pixel in the warp
+    const int32_t warp_bin_final = cg::reduce(warp, bin_final, cg::greater<int>());
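
For readers unfamiliar with cg::reduce, the following minimal standalone CUDA sketch (illustrative only; it assumes CUDA 11+ cooperative groups, and the kernel and buffer names are ours) shows the same warp-wide max-reduction pattern in isolation:

// warp_max_demo.cu -- minimal standalone sketch of the cg::reduce pattern above.
#include <cstdio>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

__global__ void warp_max_kernel(const int *values, int *warp_max_out) {
    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<32>(block);

    int v = values[blockIdx.x * blockDim.x + threadIdx.x];
    // every lane receives the maximum value held by any lane of its warp
    int warp_max = cg::reduce(warp, v, cg::greater<int>());

    if (warp.thread_rank() == 0) {
        warp_max_out[(blockIdx.x * blockDim.x + threadIdx.x) / 32] = warp_max;
    }
}

int main() {
    const int n = 64; // two warps
    int h_values[n], h_out[2] = {0, 0};
    for (int i = 0; i < n; ++i) h_values[i] = (i * 37) % 101;

    int *d_values, *d_out;
    cudaMalloc(&d_values, n * sizeof(int));
    cudaMalloc(&d_out, 2 * sizeof(int));
    cudaMemcpy(d_values, h_values, n * sizeof(int), cudaMemcpyHostToDevice);

    warp_max_kernel<<<1, n>>>(d_values, d_out);
    cudaMemcpy(h_out, d_out, 2 * sizeof(int), cudaMemcpyDeviceToHost);

    std::printf("warp 0 max = %d, warp 1 max = %d\n", h_out[0], h_out[1]);
    cudaFree(d_values);
    cudaFree(d_out);
    return 0;
}
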
+
+    /**
+     * =======================================================
+     * Calculating derivatives
+     * =======================================================
+     */
+    // loop over all batches of primitives
     for (uint32_t b = 0; b < num_batches; ++b) {
         // resync all threads before writing next batch of shared mem
         block.sync();
@@ -180,15 +222,28 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel(
         // index of gaussian to load
         // batch end is the index of the last gaussian in the batch
         // These values can be negative so must be int32 instead of uint32
+
+        // loop factors:
+        // we start at the loop end and iterate backwards
         const int32_t batch_end = range_end - 1 - block_size * b;
         const int32_t batch_size = min(block_size, batch_end + 1 - range_start);
+
+        // VERY IMPORTANT HERE!
+        // we are looping from back to front, i.e. processing the gaussians from furthest to closest;
+        // reversing the compositing order lets us recover each gaussian's transmittance from the
+        // final T and accumulate the "buffer" terms used in the alpha gradients below
        const int32_t idx = batch_end - tr;
+
+        /*
+         * Fetch Gaussian primitives and STORE THEM IN REVERSE ORDER
+         */
         if (idx >= range_start) {
             int32_t g = flatten_ids[idx]; // flatten index in [C * N] or [nnz]
             id_batch[tr] = g;
             const vec2 xy = means2d[g];
             const S opac = opacities[g];
             xy_opacity_batch[tr] = {xy.x, xy.y, opac};
+
             u_Ms_batch[tr] = {
                 ray_transforms[g * 9],
                 ray_transforms[g * 9 + 1],
                 ray_transforms[g * 9 + 2]
             };
@@ -209,31 +264,52 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel(
         }
         // wait for other threads to collect the gaussians in batch
         block.sync();
-        // process gaussians in the current batch for this pixel
+        // loop through the gaussians in the current batch for this pixel
         // 0 index is the furthest back gaussian in the batch
-        for (uint32_t t = max(0, batch_end - warp_bin_final); t < batch_size;
-             ++t) {
+        /**
+         * ==================================================
+         * BACKWARD LOOPING THROUGH PRIMITIVES
+         * ==================================================
+         */
+        for (uint32_t t = max(0, batch_end - warp_bin_final); t < batch_size; ++t) {
+
             bool valid = inside;
             if (batch_end - t > bin_final) {
                 valid = 0;
             }
-            S alpha;
-            S opac;
-            S vis;
-            S gauss_weight_3d;
-            S gauss_weight_2d;
-            S gauss_weight;
-            vec2 s;
-            vec2 d;
-            vec3 h_u;
-            vec3 h_v;
-            vec3 ray_cross;
-            vec3 w_M;
+
+            /**
+             * ==================================================
+             * Forward pass variables
+             * ==================================================
+             */
+            S alpha;           // alpha of the currently processed gaussian, per pixel
+            S opac;            // opacity of the currently processed gaussian, per pixel
+            S vis;             // visibility of the currently processed gaussian (the pure gaussian weight, not multiplied by opacity), per pixel
+            S gauss_weight_3d; // 3D gaussian weight (using the true ray-splat intersection in UV space), per pixel
+            S gauss_weight_2d; // 2D gaussian weight (using the projected 2D mean), per pixel
+            S gauss_weight;    // minimum of the 3D and 2D gaussian weights, per pixel
+
+            vec2 s;            // point of intersection in UV space, per pixel
+            vec2 d;            // 2D screen-space offset between the projected primitive center and the pixel
+            vec3 h_u;          // homogeneous plane parameter for u, per pixel
+            vec3 h_v;          // homogeneous plane parameter for v, per pixel
+            vec3 ray_cross;    // cross product of h_u and h_v: the intersection of the two planes, per pixel
+            vec3 w_M;          // depth component of the ray transform matrix, per pixel
+
+            /**
+             * ==================================================
+             * Run through the forward pass, but only for the t-th primitive
+             * ==================================================
+             */
             if (valid) {
                 vec3 xy_opac = xy_opacity_batch[t];
+
                 opac = xy_opac.z;
+
                 const vec3 u_M = u_Ms_batch[t];
                 const vec3 v_M = v_Ms_batch[t];
+
                 w_M = w_Ms_batch[t];

                 h_u = px * w_M - u_M;
@@ -246,14 +322,20 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel(
                     valid = false;
                 s = {ray_cross.x / ray_cross.z, ray_cross.y / ray_cross.z};

+                // GAUSSIAN KERNEL EVALUATION
                 gauss_weight_3d = s.x * s.x + s.y * s.y;
                 d = {xy_opac.x - px, xy_opac.y - py};
+
+                // 2D gaussian weight using the projected 2D mean
                 gauss_weight_2d = FILTER_INV_SQUARE * (d.x * d.x + d.y * d.y);
                 gauss_weight = min(gauss_weight_3d, gauss_weight_2d);

+                // visibility and alpha
                 const S sigma = 0.5f * gauss_weight;
                 vis = __expf(-sigma);
-                alpha = min(0.999f, opac * vis);
+                alpha = min(0.999f, opac * vis); // clipped alpha
+
+                // discard negligible gaussians
                 if (sigma < 0.f || alpha < 1.f / 255.f) {
                     valid = false;
                 }
@@ -263,53 +345,100 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel(
             if (!warp.any(valid)) {
                 continue;
             }
+
+            /**
+             * ==================================================
+             * Gradient variables
+             *
+             * note: the "local" suffix means the gradient of one primitive for this pixel;
+             * at the end of the loop these are warp-reduced and accumulated into the per-gaussian gradients
+             * ==================================================
+             */
+            // rgb gradients
             S v_rgb_local[COLOR_DIM] = {0.f};
+            // normal gradients
             S v_normal_local[3] = {0.f};
+
+            // ray transform gradients
             vec3 v_u_M_local = {0.f, 0.f, 0.f};
             vec3 v_v_M_local = {0.f, 0.f, 0.f};
             vec3 v_w_M_local = {0.f, 0.f, 0.f};
+
+            // 2D mean gradients, used if the 2d gaussian weight was applied
             vec2 v_xy_local = {0.f, 0.f};
+
+            // absolute 2D mean gradients, used if the 2d gaussian weight was applied
             vec2 v_xy_abs_local = {0.f, 0.f};
+
+            // opacity gradients
             S v_opacity_local = 0.f;
             // initialize everything to 0, only set if the lane is valid
+
+            /**
+             * ==================================================
+             * Calculating derivatives w.r.t. the current primitive / gaussian
+             * ==================================================
+             */
             if (valid) {
+                // gradient contribution from the median depth
                 if (batch_end - t == median_idx) {
+                    // v_median is the incoming gradient of the rendered median depth;
+                    // it is accumulated into the gradient of the last color channel (which carries depth)
                     v_rgb_local[COLOR_DIM - 1] += v_median;
                }

+                /**
+                 * d(img)/d(rgb) and d(img)/d(alpha)
+                 */
+                // compute the current T for this gaussian:
+                // since T_(i+1) = T_i * (1 - alpha_i), we recover T_i = T_(i+1) / (1 - alpha_i) while looping backwards
+                // (potential numerical stability issue if alpha -> 1)
                 S ra = 1.0f / (1.0f - alpha);
                 T *= ra;
+                // update v_rgb for this gaussian:
+                // the contribution is c_i * (a_i G_i) * T_i with T_i = prod_{j<i}(1 - a_j G_j),
+                // so d(img)/d(c_i) = (a_i G_i) * T_i, where alpha_i = a_i * G_i
                 const S fac = alpha * T;
                 GSPLAT_PRAGMA_UNROLL
                 for (uint32_t k = 0; k < COLOR_DIM; ++k) {
                     v_rgb_local[k] += fac * v_render_c[k];
                }
-                // contribution from this pixel
+
+                // contribution of this pixel to the alpha gradient:
+                // d(img)/d(alpha_i) = c_i * T_i - buffer / (1 - alpha_i),
+                // where buffer is the accumulated contribution of the gaussians behind gaussian i;
+                // this follows from differentiating the compositing equation with respect to alpha_i
                 S v_alpha = 0.f;
                 for (uint32_t k = 0; k < COLOR_DIM; ++k) {
-                    v_alpha +=
-                        (rgbs_batch[t * COLOR_DIM + k] * T - buffer[k] * ra) *
-                        v_render_c[k];
+                    v_alpha += (rgbs_batch[t * COLOR_DIM + k] * T - buffer[k] * ra) * v_render_c[k];
                }
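
The formula in the comment above can be checked independently. The following standalone scalar sketch (ours, not gsplat code) compares d(C_out)/d(alpha_i) = c_i * T_i - buffer_i / (1 - alpha_i) against a finite-difference estimate of the front-to-back compositing equation:

// alpha_gradient_check.cpp -- standalone sketch, not gsplat code.
// Verifies d(C_out)/d(alpha_i) = c_i * T_i - buffer_i / (1 - alpha_i),
// where buffer_i is the color accumulated by the gaussians behind gaussian i.
#include <cstdio>

static double composite(const double *alpha, const double *color, int n) {
    double T = 1.0, out = 0.0;
    for (int i = 0; i < n; ++i) { // front-to-back compositing
        out += color[i] * alpha[i] * T;
        T *= (1.0 - alpha[i]);
    }
    return out;
}

int main() {
    const int n = 4;
    double alpha[n] = {0.30, 0.55, 0.20, 0.70};
    double color[n] = {0.90, 0.10, 0.40, 0.80};

    for (int i = 0; i < n; ++i) {
        // T_i: transmittance in front of gaussian i
        double T = 1.0;
        for (int j = 0; j < i; ++j) T *= (1.0 - alpha[j]);
        // buffer_i: contribution of the gaussians behind gaussian i
        double buffer = 0.0, Tj = T * (1.0 - alpha[i]);
        for (int j = i + 1; j < n; ++j) {
            buffer += color[j] * alpha[j] * Tj;
            Tj *= (1.0 - alpha[j]);
        }
        double analytic = color[i] * T - buffer / (1.0 - alpha[i]);

        // central finite-difference reference
        const double eps = 1e-6;
        double a_plus[n], a_minus[n];
        for (int j = 0; j < n; ++j) { a_plus[j] = alpha[j]; a_minus[j] = alpha[j]; }
        a_plus[i] += eps;
        a_minus[i] -= eps;
        double numeric = (composite(a_plus, color, n) - composite(a_minus, color, n)) / (2.0 * eps);

        std::printf("i=%d  analytic=%.8f  numeric=%.8f\n", i, analytic, numeric);
    }
    return 0;
}
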

+                /*
+                 * d(normal_out)/d(normal) and d(normal_out)/d(alpha)
+                 */
+                // update v_normal for this gaussian
                 GSPLAT_PRAGMA_UNROLL
                 for (uint32_t k = 0; k < 3; ++k) {
                     v_normal_local[k] = fac * v_render_n[k];
                 }
+
                 for (uint32_t k = 0; k < 3; ++k) {
-                    v_alpha += (normals_batch[t * 3 + k] * T -
-                                buffer_normals[k] * ra) *
-                               v_render_n[k];
+                    v_alpha += (normals_batch[t * 3 + k] * T - buffer_normals[k] * ra) * v_render_n[k];
                 }

+                /*
+                 * d(alpha_out)/d(alpha)
+                 */
                 v_alpha += T_final * ra * v_render_a;

-                // contribution from background pixel
+                // adjust the alpha gradient for the background color:
+                // this prevents the background rendered in the fwd pass from being attributed to inaccuracies in the primitives,
+                // and lets us switch background colors to avoid overfitting to a particular background (e.g. black)
                 if (backgrounds != nullptr) {
                     S accum = 0.f;
                     GSPLAT_PRAGMA_UNROLL
@@ -338,15 +467,26 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel(
                         v_distort;
                 }

-                //====== 2DGS ======//
+                /** ==================================================
+                 * 2DGS backward pass: compute d(out)/d(G_i) and the gradients of G_i w.r.t. the geometry parameters
+                 * ==================================================
+                 */
                 if (opac * vis <= 0.999f) {
                     S v_depth = 0.f;
+                    // d(a_i * G_i) / d(G_i) = a_i
                     const S v_G = opac * v_alpha;
+
+                    // case 1: in the forward pass, the true ray-primitive intersection was used
                     if (gauss_weight_3d <= gauss_weight_2d) {
+
+                        // derivative of G_i w.r.t. the ray-primitive intersection uv coordinates
                        const vec2 v_s = {
                             v_G * -vis * s.x + v_depth * w_M.x,
                             v_G * -vis * s.y + v_depth * w_M.y
                         };
+
+                        // backward through the projective transform
+                        // @see rasterize_to_pixels_2dgs_fwd.cu to understand what is going on here
                         const vec3 v_z_w_M = {s.x, s.y, 1.0};
                         const S v_sx_pz = v_s.x / ray_cross.z;
                         const S v_sy_pz = v_s.y / ray_cross.z;
@@ -356,6 +496,7 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel(
                         const vec3 v_h_u = glm::cross(h_v, v_ray_cross);
                         const vec3 v_h_v = glm::cross(v_ray_cross, h_u);

+                        // derivative of the ray-primitive intersection uv coordinates w.r.t. the transformation (geometry) coefficients
                         v_u_M_local = {-v_h_u.x, -v_h_u.y, -v_h_u.z};
                         v_v_M_local = {-v_h_v.x, -v_h_v.y, -v_h_v.z};
                         v_w_M_local = {
@@ -364,7 +505,9 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel(
                             px * v_h_u.z + py * v_h_v.z + v_depth * v_z_w_M.z
                         };
+                    // case 2: in the forward pass, the projected 2D gaussian weight was used
                     } else {
+                        // derivative of G_i w.r.t. the projected 2d gaussian parameters (trivial)
                         const S v_G_ddelx = -vis * FILTER_INV_SQUARE * d.x;
                         const S v_G_ddely = -vis * FILTER_INV_SQUARE * d.y;
                         v_xy_local = {v_G * v_G_ddelx, v_G * v_G_ddely};
@@ -377,16 +520,28 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel(
                     v_opacity_local = vis * v_alpha;
                }

+                /**
+                 * Update the cumulative "later" gaussian contributions, used in the derivatives of the rendered color with respect to alphas
+                 */
                 GSPLAT_PRAGMA_UNROLL
                 for (uint32_t k = 0; k < COLOR_DIM; ++k) {
                     buffer[k] += rgbs_batch[t * COLOR_DIM + k] * fac;
                }

+                /**
+                 * Update the cumulative "later" gaussian contributions, used in the derivatives of the output normals w.r.t.
alphas + */ GSPLAT_PRAGMA_UNROLL for (uint32_t k = 0; k < 3; ++k) { buffer_normals[k] += normals_batch[t * 3 + k] * fac; } } + + /** + * ================================================== + * Warp-level reduction to compute the sum of the gradients for each gaussian + * ================================================== + */ warpSum(v_rgb_local, warp); warpSum<3, S>(v_normal_local, warp); warpSum(v_xy_local, warp); @@ -398,6 +553,12 @@ __global__ void rasterize_to_pixels_bwd_2dgs_kernel( } warpSum(v_opacity_local, warp); int32_t g = id_batch[t]; // flatten index in [C * N] or [nnz] + + /** + * ================================================== + * Write the gradients to the global memory + * ================================================== + */ if (warp.thread_rank() == 0) { S *v_rgb_ptr = (S *)(v_colors) + COLOR_DIM * g; GSPLAT_PRAGMA_UNROLL diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu index f3705f5f9..839a72f8b 100644 --- a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu +++ b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu @@ -14,48 +14,66 @@ namespace cg = cooperative_groups; * Rasterization to Pixels Forward Pass 2DGS ****************************************************************************/ + /** + * + */ template __global__ void rasterize_to_pixels_fwd_2dgs_kernel( - const uint32_t C, - const uint32_t N, - const uint32_t n_isects, - const bool packed, - const vec2 *__restrict__ means2d, - const S *__restrict__ ray_transforms, - const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] - const S *__restrict__ opacities, // [C, N] or [nnz] - const S *__restrict__ normals, // [C, N, 3] or [nnz, 3] - const S *__restrict__ backgrounds, // [C, COLOR_DIM] - const bool *__restrict__ masks, // [C, tile_height, tile_width] + const uint32_t C, // number of cameras + const uint32_t N, // number of gaussians + const uint32_t n_isects, // number of ray-primitive intersections. + const bool packed, // whether the input tensors are packed + const vec2 *__restrict__ means2d, // Projected Gaussian means. [C, N, 2] if packed is False, [nnz, 2] if packed is True. + const S *__restrict__ ray_transforms, // transformation matrices that transforms xy-planes in pixel spaces into splat coordinates. [C, N, 3, 3] if packed is False, [nnz, channels] if packed is True. + // This is (KWH)^{-1} in the paper (takes screen [x,y] and map to [u,v]) + const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] // Gaussian colors or ND features. + const S *__restrict__ opacities, // [C, N] or [nnz] // Gaussian opacities that support per-view values. + const S *__restrict__ normals, // [C, N, 3] or [nnz, 3] // The normals in camera space. + const S *__restrict__ backgrounds, // [C, COLOR_DIM] // Background colors on camera basis + const bool *__restrict__ masks, // [C, tile_height, tile_width] // Optional tile mask to skip rendering GS to masked tiles. const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, const uint32_t tile_width, const uint32_t tile_height, - const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] - const int32_t *__restrict__ flatten_ids, // [n_isects] + const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] // Intersection offsets outputs from `isect_offset_encode()`, this is the result of a prefix sum, and + // gives the interval that our gaussians are gonna use. 
+ const int32_t *__restrict__ flatten_ids, // [n_isects] // The global flatten indices in [C * N] or [nnz] from `isect_tiles()`. + + + // outputs S *__restrict__ render_colors, // [C, image_height, image_width, COLOR_DIM] S *__restrict__ render_alphas, // [C, image_height, image_width, 1] S *__restrict__ render_normals, // [C, image_height, image_width, 3] - S *__restrict__ render_distort, // [C, image_height, image_width, 1] - S *__restrict__ render_median, // [C, image_height, image_width, 1] - int32_t *__restrict__ last_ids, // [C, image_height, image_width] - int32_t *__restrict__ median_ids // [C, image_height, image_width] + S *__restrict__ render_distort, // [C, image_height, image_width, 1] // Stores the per-pixel distortion error proposed in Mip-NeRF 360. + S *__restrict__ render_median, // [C, image_height, image_width, 1] // Stores the median depth contribution for each pixel "set to the depth of the Gaussian that brings the accumulated opacity over 0.5." + int32_t *__restrict__ last_ids, // [C, image_height, image_width] // Stores the index of the last Gaussian that contributed to each pixel. + int32_t *__restrict__ median_ids // [C, image_height, image_width] // Stores the index of the Gaussian that contributes to the median depth for each pixel (bring over 0.5). ) { // each thread draws one pixel, but also timeshares caching gaussians in a // shared tile + /** + * ============================== + * Thread and block setup: + * This sets up the thread and block indices, determining which camera, tile, and pixel each thread will process. + * The grid structure is assigend as: + * C * tile_height * tile_width blocks (3d grid), each block is a tile. + * Each thread is responsible for one pixel. (blockSize = tile_size * tile_size) + * ============================== + */ auto block = cg::this_thread_block(); int32_t camera_id = block.group_index().x; - int32_t tile_id = - block.group_index().y * tile_width + block.group_index().z; + int32_t tile_id = block.group_index().y * tile_width + block.group_index().z; uint32_t i = block.group_index().y * tile_size + block.thread_index().y; uint32_t j = block.group_index().z * tile_size + block.thread_index().x; - tile_offsets += camera_id * tile_height * tile_width; - render_colors += camera_id * image_height * image_width * COLOR_DIM; - render_alphas += camera_id * image_height * image_width; - last_ids += camera_id * image_height * image_width; + tile_offsets += camera_id * tile_height * tile_width; // get the global offset of the tile w.r.t the camera + render_colors += camera_id * image_height * image_width * COLOR_DIM; // get the global offset of the pixel w.r.t the camera + render_alphas += camera_id * image_height * image_width; // get the global offset of the pixel w.r.t the camera + last_ids += camera_id * image_height * image_width; // get the global offset of the pixel w.r.t the camera + + // get the global offset of the background and mask if (backgrounds != nullptr) { backgrounds += camera_id * COLOR_DIM; } @@ -63,6 +81,7 @@ __global__ void rasterize_to_pixels_fwd_2dgs_kernel( masks += camera_id * tile_height * tile_width; } + // find the center of the pixel S px = (S)j + 0.5f; S py = (S)i + 0.5f; int32_t pix_id = i * image_width + j; @@ -85,33 +104,46 @@ __global__ void rasterize_to_pixels_fwd_2dgs_kernel( // have all threads in tile process the same gaussians in batches // first collect gaussians between range.x and range.y in batches // which gaussians to look through in this tile + + // print int32_t range_start = 
 tile_offsets[tile_id];
     int32_t range_end =
+        // see if this is the last tile in the camera
         (camera_id == C - 1) && (tile_id == tile_width * tile_height - 1)
             ? n_isects
             : tile_offsets[tile_id + 1];
     const uint32_t block_size = block.size();
-    uint32_t num_batches =
-        (range_end - range_start + block_size - 1) / block_size;
-
+    uint32_t num_batches = (range_end - range_start + block_size - 1) / block_size;
+
+
+    /**
+     * ==============================
+     * Per-thread working variables:
+     * for each pixel, we find its uv intersection with the gaussian primitives,
+     * retrieve the kernel's parameters and kernel weights,
+     * and apply the splatting rendering equation.
+     * ==============================
+     */
+    // Shared memory layout:
+    // | gaussian indices | x : y : alpha | u | v | w |
     extern __shared__ int s[];
     int32_t *id_batch = (int32_t *)s; // [block_size]
+
+    // stores the concatenation of the projected primitive center (x, y) and the opacity alpha
     vec3<S> *xy_opacity_batch =
         reinterpret_cast<vec3<S> *>(&id_batch[block_size]); // [block_size]
+
+    // these are the row vectors of the ray transformation matrices for the current batch of gaussians
-    vec3<S> *u_Ms_batch =
-        reinterpret_cast<vec3<S> *>(&xy_opacity_batch[block_size]
-        ); // [block_size]
-    vec3<S> *v_Ms_batch =
-        reinterpret_cast<vec3<S> *>(&u_Ms_batch[block_size]
-        ); // [block_size]
-    vec3<S> *w_Ms_batch =
-        reinterpret_cast<vec3<S> *>(&v_Ms_batch[block_size]
-        ); // [block_size]
+    vec3<S> *u_Ms_batch = reinterpret_cast<vec3<S> *>(&xy_opacity_batch[block_size]); // [block_size]
+    vec3<S> *v_Ms_batch = reinterpret_cast<vec3<S> *>(&u_Ms_batch[block_size]);       // [block_size]
+    vec3<S> *w_Ms_batch = reinterpret_cast<vec3<S> *>(&v_Ms_batch[block_size]);       // [block_size]

     // current visibility left to render
     // transmittance is gonna be used in the backward pass which requires a high
     // numerical precision so we use double for it. However double make bwd 1.5x
     // slower so we stick with float for now.
+    // The accumulated transmittance used in volumetric rendering for the pixel this thread is responsible for.
     S T = 1.0f;
     // index of most recent gaussian to write to this thread's pixel
     uint32_t cur_idx = 0;
@@ -119,7 +151,7 @@ __global__ void rasterize_to_pixels_fwd_2dgs_kernel(
     // collect and process batches of gaussians
     // each thread loads one gaussian at a time before rasterizing its
     // designated pixel
-    uint32_t tr = block.thread_rank();
+    uint32_t tr = block.thread_rank();

     // Per-pixel distortion error proposed in Mip-NeRF 360.
     // Implemented reference:
@@ -131,6 +163,14 @@ __global__ void rasterize_to_pixels_fwd_2dgs_kernel(
     S median_depth = 0.f;
     uint32_t median_idx = 0.f;

+    /**
+     * ==============================
+     * Per-pixel rendering (2DGS Differentiable Rasterizer, forward pass):
+     * this section renders a single pixel by processing batches of gaussians
+     * and accumulating the pixel color and normal.
+     * ==============================
+     */
+
     // TODO (WZ): merge pix_out and normal_out to
     // S pix_out[COLOR_DIM + 3] = {0.f}
     S pix_out[COLOR_DIM] = {0.f};
@@ -146,6 +186,11 @@ __global__ void rasterize_to_pixels_fwd_2dgs_kernel(
         // index of gaussian to load
         uint32_t batch_start = range_start + block_size * b;
         uint32_t idx = batch_start + tr;
+
+        // only threads within the range of the tile will fetch gaussians
+        /**
+         * Each thread in the block is responsible for fetching one gaussian.
+ */ if (idx < range_end) { int32_t g = flatten_ids[idx]; // flatten index in [C * N] or [nnz] id_batch[tr] = g; @@ -166,6 +211,50 @@ __global__ void rasterize_to_pixels_fwd_2dgs_kernel( // wait for other threads to collect the gaussians in batch block.sync(); + /** + * ================================================== + * Forward rasterization pass: + * ================================================== + * + * GSplat computes rasterization point of intersection as: + * 1. Generate 2 homogeneous plane parameter vectors as sets of points in UV space + * 2. Find the set of points that satisfy both conditions with the cross product + * 3. Find where this solution set intersects with UV plane using projective flattening + * + * For each gaussian G_i and pixel q_xy: + * + * 1. Compute homogeneous plane parameters: + * h_u = p_x * M_w - M_u + * h_v = p_y * M_w - M_v + * where M_u, M_v, M_w are rows of the KWH transform + * + * Note: this works because: + * for any vector q_uv [u, v, 1], applying co-vector h_u will yield the following expression: + * h_u * [u, v, 1]^T = P_x * (M_w * q_uv) - M_u * q_uv + * = P_x * q_ray.z - q_ray.x * q_ray.z + * - where P_x is the x-coordinate of the ray origin + * Thus: h_u defines a set of q_uv where q_uv's projected x coordinate in ray space is P_x + * which aligns with the homogeneous plane definition in original 2DGS paper (similar for h_v) + * + * 2. Compute intersection: + * zeta = h_u × h_v + * This cross product is the only solution that satisfies both homogeneous plane equations (dot product == 0) + * + * 3. Project to UV space: + * s_uv = [zeta_1/zeta_3, zeta_2/zeta_3] + * - since UV space is essentially another ray space, and arbitrary scale of q_uv will not change the result of dot product over orthogonality + * - thus, the result is the point of intersection in UV space + * + * 4. Evaluate gaussian kernel: + * G_i = exp(-(s_u^2 + s_v^2)/2) + * + * 5. Accumulate color: + * p_xy += alpha_i * c_i * G_i * prod(1 - alpha_j * G_j) + * + * This method efficiently computes the point of intersection and + * evaluates the gaussian kernel in UV space. + * Note: in some cases, we use the minimum of ray-intersection kernels and 2D projected gaussian kernels + */ // process gaussians in the current batch for this pixel uint32_t batch_size = min(block_size, range_end - batch_start); for (uint32_t t = 0; (t < batch_size) && !done; ++t) { @@ -177,6 +266,7 @@ __global__ void rasterize_to_pixels_fwd_2dgs_kernel( const vec3 v_M = v_Ms_batch[t]; const vec3 w_M = w_Ms_batch[t]; + // h_u and h_v are the homogeneous plane representations (they are contravariant to the points on the primitive plane) const vec3 h_u = px * w_M - u_M; const vec3 h_v = py * w_M - v_M; @@ -184,17 +274,27 @@ __global__ void rasterize_to_pixels_fwd_2dgs_kernel( if (ray_cross.z == 0.0) continue; - const vec2 s = - vec2(ray_cross.x / ray_cross.z, ray_cross.y / ray_cross.z); + const vec2 s = vec2(ray_cross.x / ray_cross.z, ray_cross.y / ray_cross.z); + + // IMPORTANT: This is where the gaussian kernel is evaluated!!!!! 
+            // point of intersection in uv space
             const S gauss_weight_3d = s.x * s.x + s.y * s.y;
+
+            // projected gaussian kernel
             const vec2 d = {xy_opac.x - px, xy_opac.y - py};
-            const S gauss_weight_2d =
-                FILTER_INV_SQUARE * (d.x * d.x + d.y * d.y);
+            // #define FILTER_INV_SQUARE 2.0f
+            const S gauss_weight_2d = FILTER_INV_SQUARE * (d.x * d.x + d.y * d.y);
+
+            // merge the ray-intersection kernel and the 2d gaussian kernel
             const S gauss_weight = min(gauss_weight_3d, gauss_weight_2d);
+
             const S sigma = 0.5f * gauss_weight;
+            // evaluation of the gaussian exponential term
             S alpha = min(0.999f, opac * __expf(-sigma));
+
+            // ignore transparent gaussians
             if (sigma < 0.f || alpha < 1.f / 255.f) {
                 continue;
             }
@@ -205,6 +305,7 @@ __global__ void rasterize_to_pixels_fwd_2dgs_kernel(
                 break;
             }

+            // run volumetric rendering
             int32_t g = id_batch[t];
             const S vis = alpha * T;
             const S *c_ptr = colors + g * COLOR_DIM;
@@ -324,6 +425,7 @@ call_kernel_with_dim(

     // Each block covers a tile on the image. In total there are
     // C * tile_height * tile_width blocks.
+    // we assign one pixel to one thread.
     dim3 threads = {tile_size, tile_size, 1};
     dim3 blocks = {C, tile_height, tile_width};
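
As a closing illustration of the intersection recipe annotated in the forward-pass comment above, here is a standalone host-side sketch (ours, not gsplat code) that builds h_u and h_v for an arbitrary ray transform M and pixel, intersects them with the cross product, and verifies that the recovered UV point reprojects onto the same pixel:

// ray_splat_intersection_check.cpp -- standalone sketch, not gsplat code.
#include <cstdio>

struct Vec3 { double x, y, z; };

static double dot(const Vec3 &a, const Vec3 &b) { return a.x * b.x + a.y * b.y + a.z * b.z; }
static Vec3 cross(const Vec3 &a, const Vec3 &b) {
    return {a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x};
}
static Vec3 scale_sub(double a, const Vec3 &x, const Vec3 &y) { // a*x - y
    return {a * x.x - y.x, a * x.y - y.y, a * x.z - y.z};
}

int main() {
    // rows of an arbitrary (well-conditioned) ray transform M = KWH
    Vec3 M_u = {310.0, 12.0, 410.5};
    Vec3 M_v = {-8.0, 295.0, 260.0};
    Vec3 M_w = {0.02, -0.05, 1.30};
    double px = 123.5, py = 87.5; // pixel center

    // homogeneous plane parameters, as in the kernel
    Vec3 h_u = scale_sub(px, M_w, M_u); // px * M_w - M_u
    Vec3 h_v = scale_sub(py, M_w, M_v); // py * M_w - M_v

    Vec3 zeta = cross(h_u, h_v);        // intersection of the two planes
    double su = zeta.x / zeta.z, sv = zeta.y / zeta.z;

    // check: M * [su, sv, 1] must project back onto the pixel (px, py)
    Vec3 q = {su, sv, 1.0};
    double x_back = dot(M_u, q) / dot(M_w, q);
    double y_back = dot(M_v, q) / dot(M_w, q);

    std::printf("s = (%.6f, %.6f)\n", su, sv);
    std::printf("reprojected pixel = (%.6f, %.6f), expected (%.1f, %.1f)\n",
                x_back, y_back, px, py);
    return 0;
}
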