From 2958850c34bb56d5edf749a64ee603ce7c3bd728 Mon Sep 17 00:00:00 2001
From: Francis Williams
Date: Wed, 10 Jul 2024 15:39:10 -0400
Subject: [PATCH 1/2] rasterize_to_indices_in_range

---
 .../csrc/rasterize_to_indices_in_range.cu | 78 ++++++++++---------
 1 file changed, 43 insertions(+), 35 deletions(-)

diff --git a/gsplat/cuda/csrc/rasterize_to_indices_in_range.cu b/gsplat/cuda/csrc/rasterize_to_indices_in_range.cu
index 227222537..515975437 100644
--- a/gsplat/cuda/csrc/rasterize_to_indices_in_range.cu
+++ b/gsplat/cuda/csrc/rasterize_to_indices_in_range.cu
@@ -5,6 +5,8 @@
 #include 
 #include 
 
+#include 
+
 namespace cg = cooperative_groups;
 
 /****************************************************************************
@@ -31,6 +33,9 @@ __global__ void rasterize_to_indices_in_range_kernel(
     // each thread draws one pixel, but also timeshares caching gaussians in a
     // shared tile
 
+    // For now we'll upcast float16 and bfloat16 to float32
+    using OpT = typename OpType<T>::type;
+
     auto block = cg::this_thread_block();
     uint32_t camera_id = block.group_index().x;
     uint32_t tile_id = block.group_index().y * tile_width + block.group_index().z;
@@ -41,8 +46,8 @@ __global__ void rasterize_to_indices_in_range_kernel(
     tile_offsets += camera_id * tile_height * tile_width;
     transmittances += camera_id * image_height * image_width;
 
-    T px = (T)j + 0.5f;
-    T py = (T)i + 0.5f;
+    OpT px = (OpT)j + 0.5f;
+    OpT py = (OpT)i + 0.5f;
     int32_t pix_id = i * image_width + j;
 
     // return if out of bounds
@@ -77,16 +82,16 @@ __global__ void rasterize_to_indices_in_range_kernel(
     extern __shared__ int s[];
     int32_t *id_batch = (int32_t *)s; // [block_size]
-    vec3<T> *xy_opacity_batch =
+    vec3<OpT> *xy_opacity_batch =
         reinterpret_cast<vec3<float> *>(&id_batch[block_size]); // [block_size]
-    vec3<T> *conic_batch =
+    vec3<OpT> *conic_batch =
         reinterpret_cast<vec3<float> *>(&xy_opacity_batch[block_size]); // [block_size]
 
     // current visibility left to render
     // transmittance is gonna be used in the backward pass which requires a high
     // numerical precision so we (should) use double for it. However double make
     // bwd 1.5x slower so we stick with float for now.
-    T trans, next_trans;
+    OpT trans, next_trans;
     if (inside) {
         trans = transmittances[pix_id];
         next_trans = trans;
     }
@@ -112,10 +117,10 @@ __global__ void rasterize_to_indices_in_range_kernel(
         if (idx < isect_range_end) {
             int32_t g = flatten_ids[idx];
             id_batch[tr] = g;
-            const vec2<T> xy = means2d[g];
-            const T opac = opacities[g];
-            xy_opacity_batch[tr] = {xy.x, xy.y, opac};
-            conic_batch[tr] = conics[g];
+            const vec2<OpT> xy = means2d[g];
+            const OpT opac = opacities[g];
+            xy_opacity_batch[tr] = vec3<OpT>(xy.x, xy.y, opac);
+            conic_batch[tr] = vec3<OpT>(conics[g]);
         }
 
         // wait for other threads to collect the gaussians in batch
@@ -124,14 +129,14 @@ __global__ void rasterize_to_indices_in_range_kernel(
         // process gaussians in the current batch for this pixel
         uint32_t batch_size = min(block_size, isect_range_end - batch_start);
         for (uint32_t t = 0; (t < batch_size) && !done; ++t) {
-            const vec3<T> conic = conic_batch[t];
-            const vec3<T> xy_opac = xy_opacity_batch[t];
-            const T opac = xy_opac.z;
-            const vec2<T> delta = {xy_opac.x - px, xy_opac.y - py};
-            const T sigma =
+            const vec3<OpT> conic = vec3<OpT>(conic_batch[t]);
+            const vec3<OpT> xy_opac = vec3<OpT>(xy_opacity_batch[t]);
+            const OpT opac = xy_opac.z;
+            const vec2<OpT> delta = {xy_opac.x - px, xy_opac.y - py};
+            const OpT sigma =
                 0.5f * (conic.x * delta.x * delta.x + conic.z * delta.y * delta.y) + conic.y * delta.x * delta.y;
-            T alpha = min(0.999f, opac * __expf(-sigma));
+            OpT alpha = min(0.999f, opac * c10::cuda::compat::exp(-sigma));
 
             if (sigma < 0.f || alpha < 1.f / 255.f) {
                 continue;
             }
@@ -214,16 +219,17 @@ std::tuple<torch::Tensor, torch::Tensor> rasterize_to_indices_in_range_tensor(
     if (n_isects) {
        torch::Tensor chunk_cnts = torch::zeros({C * image_height * image_width},
                                                means2d.options().dtype(torch::kInt32));
-        rasterize_to_indices_in_range_kernel<float>
-            <<>>(
-                range_start, range_end, C, N, n_isects,
-                reinterpret_cast<vec2<float> *>(means2d.data_ptr<float>()),
-                reinterpret_cast<vec3<float> *>(conics.data_ptr<float>()),
-                opacities.data_ptr<float>(), image_width, image_height, tile_size,
-                tile_width, tile_height, tile_offsets.data_ptr<int32_t>(),
-                flatten_ids.data_ptr<int32_t>(), transmittances.data_ptr<float>(),
-                nullptr, chunk_cnts.data_ptr<int32_t>(), nullptr, nullptr);
-
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, means2d.scalar_type(), "rasterize_to_indices_in_range", [&]() {
+            rasterize_to_indices_in_range_kernel<scalar_t>
+                <<>>(
+                    range_start, range_end, C, N, n_isects,
+                    reinterpret_cast<vec2<scalar_t> *>(means2d.data_ptr<scalar_t>()),
+                    reinterpret_cast<vec3<scalar_t> *>(conics.data_ptr<scalar_t>()),
+                    opacities.data_ptr<scalar_t>(), image_width, image_height, tile_size,
+                    tile_width, tile_height, tile_offsets.data_ptr<int32_t>(),
+                    flatten_ids.data_ptr<int32_t>(), transmittances.data_ptr<scalar_t>(),
+                    nullptr, chunk_cnts.data_ptr<int32_t>(), nullptr, nullptr);
+        });
        torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
        n_elems = cumsum[-1].item<int64_t>();
        chunk_starts = cumsum - chunk_cnts;
@@ -237,16 +243,18 @@ std::tuple<torch::Tensor, torch::Tensor> rasterize_to_indices_in_range_tensor(
     torch::Tensor pixel_ids =
         torch::empty({n_elems}, means2d.options().dtype(torch::kInt64));
     if (n_elems) {
-        rasterize_to_indices_in_range_kernel<float>
-            <<>>(
-                range_start, range_end, C, N, n_isects,
-                reinterpret_cast<vec2<float> *>(means2d.data_ptr<float>()),
-                reinterpret_cast<vec3<float> *>(conics.data_ptr<float>()),
-                opacities.data_ptr<float>(), image_width, image_height, tile_size,
-                tile_width, tile_height, tile_offsets.data_ptr<int32_t>(),
-                flatten_ids.data_ptr<int32_t>(), transmittances.data_ptr<float>(),
-                chunk_starts.data_ptr<int32_t>(), nullptr,
-                gaussian_ids.data_ptr<int64_t>(), pixel_ids.data_ptr<int64_t>());
+        AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, means2d.scalar_type(), "rasterize_to_indices_in_range", [&]() {
+            rasterize_to_indices_in_range_kernel<scalar_t>
+                <<>>(
+                    range_start, range_end, C, N, n_isects,
+                    reinterpret_cast<vec2<scalar_t> *>(means2d.data_ptr<scalar_t>()),
+                    reinterpret_cast<vec3<scalar_t> *>(conics.data_ptr<scalar_t>()),
+                    opacities.data_ptr<scalar_t>(), image_width, image_height, tile_size,
+                    tile_width, tile_height, tile_offsets.data_ptr<int32_t>(),
+                    flatten_ids.data_ptr<int32_t>(), transmittances.data_ptr<scalar_t>(),
+                    chunk_starts.data_ptr<int32_t>(), nullptr,
+                    gaussian_ids.data_ptr<int64_t>(), pixel_ids.data_ptr<int64_t>());
+        });
     }
     return std::make_tuple(gaussian_ids, pixel_ids);
 }

From ee33dd66d26e435fe58c1f9caa92cbb81f243ac3 Mon Sep 17 00:00:00 2001
From: Francis Williams
Date: Thu, 11 Jul 2024 14:45:03 -0400
Subject: [PATCH 2/2] fix bugs

---
 gsplat/cuda/csrc/quat_scale_to_covar_preci_bwd.cu | 8 ++++----
 gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu | 6 +++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/gsplat/cuda/csrc/quat_scale_to_covar_preci_bwd.cu b/gsplat/cuda/csrc/quat_scale_to_covar_preci_bwd.cu
index cf601c0dc..9d537855e 100644
--- a/gsplat/cuda/csrc/quat_scale_to_covar_preci_bwd.cu
+++ b/gsplat/cuda/csrc/quat_scale_to_covar_preci_bwd.cu
@@ -42,8 +42,8 @@ __global__ void quat_scale_to_covar_preci_bwd_kernel(
     v_scales += idx * 3;
     v_quats += idx * 4;
 
-    vec4<OpT> quat = glm::make_vec4(quats + idx * 4);
-    vec3<OpT> scale = glm::make_vec3(scales + idx * 3);
+    vec4<OpT> quat = vec4<OpT>(glm::make_vec4(quats + idx * 4));
+    vec3<OpT> scale = vec3<OpT>(glm::make_vec3(scales + idx * 3));
     mat3<OpT> rotmat = quat_to_rotmat(quat);
 
     vec4<OpT> v_quat(0.f);
@@ -58,7 +58,7 @@ __global__ void quat_scale_to_covar_preci_bwd_kernel(
                             v_covars[2] * .5f, v_covars[4] * .5f, v_covars[5]);
         } else {
             v_covars += idx * 9;
-            mat3<OpT> v_covar_cast = glm::make_mat3(v_covars);
+            mat3<OpT> v_covar_cast = mat3<OpT>(glm::make_mat3(v_covars));
             v_covar = glm::transpose(v_covar_cast);
         }
         quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale);
@@ -73,7 +73,7 @@ __global__ void quat_scale_to_covar_preci_bwd_kernel(
                             v_precis[2] * .5f, v_precis[4] * .5f, v_precis[5]);
         } else {
             v_precis += idx * 9;
-            mat3<OpT> v_precis_cast = glm::make_mat3(v_precis);
+            mat3<OpT> v_precis_cast = mat3<OpT>(glm::make_mat3(v_precis));
             v_preci = glm::transpose(v_precis_cast);
         }
         quat_scale_to_preci_vjp(quat, scale, rotmat, v_preci, v_quat, v_scale);
diff --git a/gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu b/gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu
index 0b74ec903..00c87771a 100644
--- a/gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu
+++ b/gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu
@@ -121,9 +121,9 @@ quat_scale_to_covar_preci_fwd_tensor(const torch::Tensor &quats, // [N, 4]
     AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, quats.scalar_type(), "quat_scale_to_covar_preci_fwd", [&]() {
         quat_scale_to_covar_preci_fwd_kernel<scalar_t><<<(N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>(
-            N, quats.data_ptr<float>(), scales.data_ptr<float>(), triu,
-            compute_covar ? covars.data_ptr<float>() : nullptr,
-            compute_preci ? precis.data_ptr<float>() : nullptr);
+            N, quats.data_ptr<scalar_t>(), scales.data_ptr<scalar_t>(), triu,
+            compute_covar ? covars.data_ptr<scalar_t>() : nullptr,
+            compute_preci ? precis.data_ptr<scalar_t>() : nullptr);
     });
     }
 
     return std::make_tuple(covars, precis);
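
Note on the pattern used in both patches: AT_DISPATCH_FLOATING_TYPES_AND2 instantiates the lambda once per supported dtype (float32, float64, float16, bfloat16) with scalar_t bound to that dtype, and the kernels then promote the 16-bit storage types to float32 for arithmetic through the OpType<T>::type alias ("OpT"). The sketch below shows that combination in isolation. It is a minimal, hypothetical example and not gsplat source: UpcastT, exp_kernel, and launch_exp are stand-ins invented for this note, while AT_DISPATCH_FLOATING_TYPES_AND2, scalar_t, and c10::cuda::compat::exp are the PyTorch facilities the patches actually use.

// Minimal sketch (assumed names, not gsplat code): dispatch a kernel over
// float/double/half/bfloat16 tensors while computing in float32 for the 16-bit types.
#include <ATen/Dispatch.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <torch/torch.h>

// Hypothetical stand-in for gsplat's OpType trait: storage type -> compute type.
template <typename T> struct UpcastT { using type = T; };
template <> struct UpcastT<c10::Half> { using type = float; };
template <> struct UpcastT<c10::BFloat16> { using type = float; };

template <typename T>
__global__ void exp_kernel(const T *in, T *out, int64_t n) {
    using OpT = typename UpcastT<T>::type; // float for half/bfloat16 inputs
    int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
    if (i >= n)
        return;
    // Load in storage precision, compute in OpT, store back in storage precision.
    // c10::cuda::compat::exp resolves to expf/exp for float/double, mirroring the
    // __expf -> c10::cuda::compat::exp change in PATCH 1/2.
    OpT x = (OpT)in[i];
    out[i] = (T)c10::cuda::compat::exp(x);
}

void launch_exp(const torch::Tensor &in, torch::Tensor &out) {
    const int64_t n = in.numel();
    const int threads = 256;
    const int blocks = (int)((n + threads - 1) / threads);
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half, at::ScalarType::BFloat16, in.scalar_type(),
        "launch_exp", [&]() {
            exp_kernel<scalar_t><<<blocks, threads>>>(
                in.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), n);
        });
}

The hunks above follow the same shape: the dispatch macro wraps the launches of rasterize_to_indices_in_range_kernel and quat_scale_to_covar_preci_fwd_kernel, and the OpT upcast keeps the per-pixel transmittance and the covariance math in float32 even when the tensors are stored in half precision.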