From c45cbdc86fbf1b8b7dc41a33b6f741e9d08f08dc Mon Sep 17 00:00:00 2001 From: "J.Y" <132313008+jb-ye@users.noreply.github.com> Date: Wed, 28 Feb 2024 14:54:27 -0500 Subject: [PATCH] Fix backprop grad of cov2d and unit tests (#136) --- gsplat/_torch_impl.py | 11 ++++++++--- gsplat/cuda/csrc/backward.cu | 6 +++--- gsplat/cuda/csrc/helpers.cuh | 2 +- gsplat/version.py | 2 +- tests/test_cov2d_bounds.py | 6 ++++-- tests/test_project_gaussians.py | 6 +++--- 6 files changed, 20 insertions(+), 13 deletions(-) diff --git a/gsplat/_torch_impl.py b/gsplat/_torch_impl.py index e5a864c7e..e897eeb73 100644 --- a/gsplat/_torch_impl.py +++ b/gsplat/_torch_impl.py @@ -113,9 +113,9 @@ def eval_sh_bases(basis_dim: int, dirs: torch.Tensor): return result -def quat_to_rotmat(quat: Tensor) -> Tensor: +def normalized_quat_to_rotmat(quat: Tensor) -> Tensor: assert quat.shape[-1] == 4, quat.shape - w, x, y, z = torch.unbind(F.normalize(quat, dim=-1), dim=-1) + w, x, y, z = torch.unbind(quat, dim=-1) mat = torch.stack( [ 1 - 2 * (y**2 + z**2), @@ -133,11 +133,16 @@ def quat_to_rotmat(quat: Tensor) -> Tensor: return mat.reshape(quat.shape[:-1] + (3, 3)) +def quat_to_rotmat(quat: Tensor) -> Tensor: + assert quat.shape[-1] == 4, quat.shape + return normalized_quat_to_rotmat(F.normalize(quat, dim=-1)) + + def scale_rot_to_cov3d(scale: Tensor, glob_scale: float, quat: Tensor) -> Tensor: assert scale.shape[-1] == 3, scale.shape assert quat.shape[-1] == 4, quat.shape assert scale.shape[:-1] == quat.shape[:-1], (scale.shape, quat.shape) - R = quat_to_rotmat(quat) # (..., 3, 3) + R = normalized_quat_to_rotmat(quat) # (..., 3, 3) M = R * glob_scale * scale[..., None, :] # (..., 3, 3) # TODO: save upper right because symmetric return M @ M.transpose(-1, -2) # (..., 3, 3) diff --git a/gsplat/cuda/csrc/backward.cu b/gsplat/cuda/csrc/backward.cu index 1e0c3fa98..13090e4c9 100644 --- a/gsplat/cuda/csrc/backward.cu +++ b/gsplat/cuda/csrc/backward.cu @@ -497,9 +497,9 @@ __device__ void scale_rot_to_cov3d_vjp( // df/dW = G * XT, df/dX = WT * G glm::mat3 v_M = 2.f * v_V * M; // glm::mat3 v_S = glm::transpose(R) * v_M; - v_scale.x = (float)glm::dot(R[0], v_M[0]); - v_scale.y = (float)glm::dot(R[1], v_M[1]); - v_scale.z = (float)glm::dot(R[2], v_M[2]); + v_scale.x = (float)glm::dot(R[0], v_M[0]) * glob_scale; + v_scale.y = (float)glm::dot(R[1], v_M[1]) * glob_scale; + v_scale.z = (float)glm::dot(R[2], v_M[2]) * glob_scale; glm::mat3 v_R = v_M * S; v_quat = quat_to_rotmat_vjp(quat, v_R); diff --git a/gsplat/cuda/csrc/helpers.cuh b/gsplat/cuda/csrc/helpers.cuh index ef6902253..bcf69f13c 100644 --- a/gsplat/cuda/csrc/helpers.cuh +++ b/gsplat/cuda/csrc/helpers.cuh @@ -75,7 +75,7 @@ inline __device__ void cov2d_to_conic_vjp( // conic = inverse cov2d // df/d_cov2d = -conic * df/d_conic * conic glm::mat2 X = glm::mat2(conic.x, conic.y, conic.y, conic.z); - glm::mat2 G = glm::mat2(v_conic.x, v_conic.y, v_conic.y, v_conic.z); + glm::mat2 G = glm::mat2(v_conic.x, v_conic.y / 2.f, v_conic.y / 2.f, v_conic.z); glm::mat2 v_Sigma = -X * G * X; v_cov2d.x = v_Sigma[0][0]; v_cov2d.y = v_Sigma[1][0] + v_Sigma[0][1]; diff --git a/gsplat/version.py b/gsplat/version.py index 0a8da8825..f1380eede 100644 --- a/gsplat/version.py +++ b/gsplat/version.py @@ -1 +1 @@ -__version__ = "0.1.6" +__version__ = "0.1.7" diff --git a/tests/test_cov2d_bounds.py b/tests/test_cov2d_bounds.py index 8045d4d46..569933a1f 100644 --- a/tests/test_cov2d_bounds.py +++ b/tests/test_cov2d_bounds.py @@ -31,8 +31,10 @@ def test_compare_binding_to_pytorch(): radii = radii.squeeze(-1) - torch.testing.assert_close(conic[_mask], _conic[_mask]) - torch.testing.assert_close(radii[_mask], _radii[_mask]) + atol = 5e-4 + rtol = 1e-5 + torch.testing.assert_close(conic[_mask], _conic[_mask], atol=atol, rtol=rtol) + torch.testing.assert_close(radii[_mask], _radii[_mask], atol=atol, rtol=rtol) if __name__ == "__main__": diff --git a/tests/test_project_gaussians.py b/tests/test_project_gaussians.py index 156a3389f..d2da32c29 100644 --- a/tests/test_project_gaussians.py +++ b/tests/test_project_gaussians.py @@ -42,7 +42,7 @@ def test_project_gaussians_forward(): means3d = torch.randn((num_points, 3), device=device, requires_grad=True) scales = torch.rand((num_points, 3), device=device) + 0.2 - glob_scale = 1.0 + glob_scale = 0.1 quats = torch.randn((num_points, 4), device=device) quats /= torch.linalg.norm(quats, dim=-1, keepdim=True) @@ -131,7 +131,7 @@ def test_project_gaussians_backward(): means3d = torch.randn((num_points, 3), device=device, requires_grad=True) scales = torch.rand((num_points, 3), device=device) + 0.2 - glob_scale = 1.0 + glob_scale = 0.1 quats = torch.randn((num_points, 4), device=device) quats /= torch.linalg.norm(quats, dim=-1, keepdim=True) @@ -184,7 +184,7 @@ def test_project_gaussians_backward(): # v_depths = torch.randn_like(depths) v_depths = torch.zeros_like(depths) # scale gradients by pixels to account for finite difference - v_conics = torch.randn_like(conics) * 1e-3 + v_conics = torch.randn_like(conics) v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat = _C.project_gaussians_backward( num_points, means3d,