diff --git a/gsplat/_torch_impl.py b/gsplat/_torch_impl.py index e897eeb73..2724796be 100644 --- a/gsplat/_torch_impl.py +++ b/gsplat/_torch_impl.py @@ -186,7 +186,19 @@ def project_cov3d_ewa( cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3 det_blur = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1] compensation = torch.sqrt(torch.clamp(det_orig / det_blur, min=0)) - return cov2d[..., :2, :2], compensation.detach() + return cov2d[..., :2, :2], compensation + + +def compute_compensation(cov2d_mat: Tensor): + """ + params: cov2d matrix (*, 2, 2) + returns: compensation factor as calculated in project_cov3d_ewa + """ + det_denom = cov2d_mat[..., 0, 0] * cov2d_mat[..., 1, 1] - cov2d_mat[..., 0, 1] ** 2 + det_nomin = (cov2d_mat[..., 0, 0] - 0.3) * (cov2d_mat[..., 1, 1] - 0.3) - cov2d_mat[ + ..., 0, 1 + ] ** 2 + return torch.sqrt(torch.clamp(det_nomin / det_denom, min=0)) def compute_cov2d_bounds(cov2d_mat: Tensor): diff --git a/gsplat/cuda/csrc/backward.cu b/gsplat/cuda/csrc/backward.cu index c6e7b3c86..bf71c1d13 100644 --- a/gsplat/cuda/csrc/backward.cu +++ b/gsplat/cuda/csrc/backward.cu @@ -331,9 +331,11 @@ __global__ void project_gaussians_backward_kernel( const float* __restrict__ cov3d, const int* __restrict__ radii, const float3* __restrict__ conics, + const float* __restrict__ compensation, const float2* __restrict__ v_xy, const float* __restrict__ v_depth, const float3* __restrict__ v_conic, + const float* __restrict__ v_compensation, float3* __restrict__ v_cov2d, float* __restrict__ v_cov3d, float3* __restrict__ v_mean3d, @@ -362,6 +364,7 @@ __global__ void project_gaussians_backward_kernel( // get v_cov2d cov2d_to_conic_vjp(conics[idx], v_conic[idx], v_cov2d[idx]); + cov2d_to_compensation_vjp(compensation[idx], conics[idx], v_compensation[idx], v_cov2d[idx]); // get v_cov3d (and v_mean3d contribution) project_cov3d_ewa_vjp( p_world, diff --git a/gsplat/cuda/csrc/backward.cuh b/gsplat/cuda/csrc/backward.cuh index f45d2d3a7..848841d6e 100644 --- a/gsplat/cuda/csrc/backward.cuh +++ b/gsplat/cuda/csrc/backward.cuh @@ -18,9 +18,11 @@ __global__ void project_gaussians_backward_kernel( const float* __restrict__ cov3d, const int* __restrict__ radii, const float3* __restrict__ conics, + const float* __restrict__ compensation, const float2* __restrict__ v_xy, const float* __restrict__ v_depth, const float3* __restrict__ v_conic, + const float* __restrict__ v_compensation, float3* __restrict__ v_cov2d, float* __restrict__ v_cov3d, float3* __restrict__ v_mean3d, diff --git a/gsplat/cuda/csrc/bindings.cu b/gsplat/cuda/csrc/bindings.cu index aa0879ba6..4b94eaf51 100644 --- a/gsplat/cuda/csrc/bindings.cu +++ b/gsplat/cuda/csrc/bindings.cu @@ -225,10 +225,12 @@ project_gaussians_backward_tensor( torch::Tensor &cov3d, torch::Tensor &radii, torch::Tensor &conics, + torch::Tensor &compensation, torch::Tensor &v_xy, torch::Tensor &v_depth, - torch::Tensor &v_conic -) { + torch::Tensor &v_conic, + torch::Tensor &v_compensation +){ DEVICE_GUARD(means3d); dim3 img_size_dim3; img_size_dim3.x = img_width; @@ -265,9 +267,11 @@ project_gaussians_backward_tensor( cov3d.contiguous().data_ptr(), radii.contiguous().data_ptr(), (float3 *)conics.contiguous().data_ptr(), + (float *)compensation.contiguous().data_ptr(), (float2 *)v_xy.contiguous().data_ptr(), v_depth.contiguous().data_ptr(), (float3 *)v_conic.contiguous().data_ptr(), + (float *)v_compensation.contiguous().data_ptr(), // Outputs. (float3 *)v_cov2d.contiguous().data_ptr(), v_cov3d.contiguous().data_ptr(), diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index f6e40a904..8aa3d4801 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -86,9 +86,11 @@ project_gaussians_backward_tensor( torch::Tensor &cov3d, torch::Tensor &radii, torch::Tensor &conics, + torch::Tensor &compensation, torch::Tensor &v_xy, torch::Tensor &v_depth, - torch::Tensor &v_conic + torch::Tensor &v_conic, + torch::Tensor &v_compensation ); diff --git a/gsplat/cuda/csrc/helpers.cuh b/gsplat/cuda/csrc/helpers.cuh index bbaff89d2..c07f5752a 100644 --- a/gsplat/cuda/csrc/helpers.cuh +++ b/gsplat/cuda/csrc/helpers.cuh @@ -82,6 +82,21 @@ inline __device__ void cov2d_to_conic_vjp( v_cov2d.z = v_Sigma[1][1]; } +inline __device__ void cov2d_to_compensation_vjp( + const float compensation, const float3 &conic, const float v_compensation, float3 &v_cov2d +) { + // comp = sqrt(det(cov2d - 0.3 I) / det(cov2d)) + // conic = inverse(cov2d) + // df / d_cov2d = df / d comp * 0.5 / comp * [ d comp^2 / d cov2d ] + // d comp^2 / d cov2d = (1 - comp^2) * conic - 0.3 I * det(conic) + float inv_det = conic.x * conic.z - conic.y * conic.y; + float one_minus_sqr_comp = 1 - compensation * compensation; + float v_sqr_comp = v_compensation * 0.5 / (compensation + 1e-6); + v_cov2d.x += v_sqr_comp * (one_minus_sqr_comp * conic.x - 0.3 * inv_det); + v_cov2d.y += 2 * v_sqr_comp * (one_minus_sqr_comp * conic.y); + v_cov2d.z += v_sqr_comp * (one_minus_sqr_comp * conic.z - 0.3 * inv_det); +} + // helper for applying R * p + T, expect mat to be ROW MAJOR inline __device__ float3 transform_4x3(const float *mat, const float3 p) { float3 out = { diff --git a/gsplat/project_gaussians.py b/gsplat/project_gaussians.py index d2c531322..3c57283f7 100644 --- a/gsplat/project_gaussians.py +++ b/gsplat/project_gaussians.py @@ -147,6 +147,7 @@ def forward( cov3d, radii, conics, + compensation, ) return (xys, depths, radii, conics, compensation, num_tiles_hit, cov3d) @@ -171,6 +172,7 @@ def backward( cov3d, radii, conics, + compensation, ) = ctx.saved_tensors (v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat) = _C.project_gaussians_backward( @@ -190,9 +192,11 @@ def backward( cov3d, radii, conics, + compensation, v_xys, v_depths, v_conics, + v_compensation, ) # Return a gradient for each input. diff --git a/tests/test_project_gaussians.py b/tests/test_project_gaussians.py index d9507d984..3c8594150 100644 --- a/tests/test_project_gaussians.py +++ b/tests/test_project_gaussians.py @@ -163,7 +163,7 @@ def test_project_gaussians_backward(): depths, radii, conics, - _, + compensation, _, masks, ) = _torch_impl.project_gaussians_forward( @@ -184,6 +184,8 @@ def test_project_gaussians_backward(): v_xys = torch.randn_like(xys) v_depths = torch.randn_like(depths) v_conics = torch.randn_like(conics) + # compensation gradients + v_compensation = torch.randn_like(compensation) v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat = _C.project_gaussians_backward( num_points, means3d, @@ -201,9 +203,11 @@ def test_project_gaussians_backward(): cov3d, radii, conics, + compensation, v_xys, v_depths, v_conics, + v_compensation, ) def scale_rot_to_cov3d_partial(scale, quat): @@ -243,6 +247,16 @@ def compute_cov2d_bounds_partial(cov2d): conic, _, _ = _torch_impl.compute_cov2d_bounds(cov2d_mat) return conic + def compute_compensation_partial(cov2d): + """ + cov2d (upper tri) (*, 3) -> compensation (*, 1) + """ + cov2d_mat = torch.zeros(*cov2d.shape[:-1], 2, 2, device=device) + i, j = torch.triu_indices(2, 2) + cov2d_mat[..., i, j] = cov2d + cov2d_mat[..., 1, 0] = cov2d[..., 1] + return _torch_impl.compute_compensation(cov2d_mat) + def project_pix_partial(mean3d): """ mean3d (*, 3) -> xy (*, 2) @@ -260,10 +274,12 @@ def compute_depth_partial(mean3d): _, vjp_scale_rot_to_cov3d = vjp(scale_rot_to_cov3d_partial, scales, quats) # type: ignore _, vjp_project_cov3d_ewa = vjp(project_cov3d_ewa_partial, means3d, cov3d) # type: ignore _, vjp_compute_cov2d_bounds = vjp(compute_cov2d_bounds_partial, cov2d) # type: ignore + _, vjp_compute_compensation = vjp(compute_compensation_partial, cov2d) # type: ignore _, vjp_project_pix = vjp(project_pix_partial, means3d) # type: ignore _, vjp_compute_depth = vjp(compute_depth_partial, means3d) # type: ignore _v_cov2d = vjp_compute_cov2d_bounds(v_conics)[0] + _v_cov2d = _v_cov2d + vjp_compute_compensation(v_compensation)[0] _v_mean3d_cov2d, _v_cov3d = vjp_project_cov3d_ewa(_v_cov2d) _v_mean3d_xy = vjp_project_pix(v_xys)[0] _v_mean3d_depth = vjp_compute_depth(v_depths)[0]