Skip to content

Commit

Permalink
Fix backprop grad of cov2d and unit tests (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored Feb 28, 2024
1 parent 24215cb commit c45cbdc
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 13 deletions.
11 changes: 8 additions & 3 deletions gsplat/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def eval_sh_bases(basis_dim: int, dirs: torch.Tensor):
return result


def quat_to_rotmat(quat: Tensor) -> Tensor:
def normalized_quat_to_rotmat(quat: Tensor) -> Tensor:
assert quat.shape[-1] == 4, quat.shape
w, x, y, z = torch.unbind(F.normalize(quat, dim=-1), dim=-1)
w, x, y, z = torch.unbind(quat, dim=-1)
mat = torch.stack(
[
1 - 2 * (y**2 + z**2),
Expand All @@ -133,11 +133,16 @@ def quat_to_rotmat(quat: Tensor) -> Tensor:
return mat.reshape(quat.shape[:-1] + (3, 3))


def quat_to_rotmat(quat: Tensor) -> Tensor:
assert quat.shape[-1] == 4, quat.shape
return normalized_quat_to_rotmat(F.normalize(quat, dim=-1))


def scale_rot_to_cov3d(scale: Tensor, glob_scale: float, quat: Tensor) -> Tensor:
assert scale.shape[-1] == 3, scale.shape
assert quat.shape[-1] == 4, quat.shape
assert scale.shape[:-1] == quat.shape[:-1], (scale.shape, quat.shape)
R = quat_to_rotmat(quat) # (..., 3, 3)
R = normalized_quat_to_rotmat(quat) # (..., 3, 3)
M = R * glob_scale * scale[..., None, :] # (..., 3, 3)
# TODO: save upper right because symmetric
return M @ M.transpose(-1, -2) # (..., 3, 3)
Expand Down
6 changes: 3 additions & 3 deletions gsplat/cuda/csrc/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,9 @@ __device__ void scale_rot_to_cov3d_vjp(
// df/dW = G * XT, df/dX = WT * G
glm::mat3 v_M = 2.f * v_V * M;
// glm::mat3 v_S = glm::transpose(R) * v_M;
v_scale.x = (float)glm::dot(R[0], v_M[0]);
v_scale.y = (float)glm::dot(R[1], v_M[1]);
v_scale.z = (float)glm::dot(R[2], v_M[2]);
v_scale.x = (float)glm::dot(R[0], v_M[0]) * glob_scale;
v_scale.y = (float)glm::dot(R[1], v_M[1]) * glob_scale;
v_scale.z = (float)glm::dot(R[2], v_M[2]) * glob_scale;

glm::mat3 v_R = v_M * S;
v_quat = quat_to_rotmat_vjp(quat, v_R);
Expand Down
2 changes: 1 addition & 1 deletion gsplat/cuda/csrc/helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ inline __device__ void cov2d_to_conic_vjp(
// conic = inverse cov2d
// df/d_cov2d = -conic * df/d_conic * conic
glm::mat2 X = glm::mat2(conic.x, conic.y, conic.y, conic.z);
glm::mat2 G = glm::mat2(v_conic.x, v_conic.y, v_conic.y, v_conic.z);
glm::mat2 G = glm::mat2(v_conic.x, v_conic.y / 2.f, v_conic.y / 2.f, v_conic.z);
glm::mat2 v_Sigma = -X * G * X;
v_cov2d.x = v_Sigma[0][0];
v_cov2d.y = v_Sigma[1][0] + v_Sigma[0][1];
Expand Down
2 changes: 1 addition & 1 deletion gsplat/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.6"
__version__ = "0.1.7"
6 changes: 4 additions & 2 deletions tests/test_cov2d_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ def test_compare_binding_to_pytorch():

radii = radii.squeeze(-1)

torch.testing.assert_close(conic[_mask], _conic[_mask])
torch.testing.assert_close(radii[_mask], _radii[_mask])
atol = 5e-4
rtol = 1e-5
torch.testing.assert_close(conic[_mask], _conic[_mask], atol=atol, rtol=rtol)
torch.testing.assert_close(radii[_mask], _radii[_mask], atol=atol, rtol=rtol)


if __name__ == "__main__":
Expand Down
6 changes: 3 additions & 3 deletions tests/test_project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_project_gaussians_forward():

means3d = torch.randn((num_points, 3), device=device, requires_grad=True)
scales = torch.rand((num_points, 3), device=device) + 0.2
glob_scale = 1.0
glob_scale = 0.1
quats = torch.randn((num_points, 4), device=device)
quats /= torch.linalg.norm(quats, dim=-1, keepdim=True)

Expand Down Expand Up @@ -131,7 +131,7 @@ def test_project_gaussians_backward():

means3d = torch.randn((num_points, 3), device=device, requires_grad=True)
scales = torch.rand((num_points, 3), device=device) + 0.2
glob_scale = 1.0
glob_scale = 0.1
quats = torch.randn((num_points, 4), device=device)
quats /= torch.linalg.norm(quats, dim=-1, keepdim=True)

Expand Down Expand Up @@ -184,7 +184,7 @@ def test_project_gaussians_backward():
# v_depths = torch.randn_like(depths)
v_depths = torch.zeros_like(depths)
# scale gradients by pixels to account for finite difference
v_conics = torch.randn_like(conics) * 1e-3
v_conics = torch.randn_like(conics)
v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat = _C.project_gaussians_backward(
num_points,
means3d,
Expand Down

0 comments on commit c45cbdc

Please sign in to comment.