Skip to content

Commit

Permalink
Fix v_mean3d in project_gaussians and v_conic calculation in *rasteri…
Browse files Browse the repository at this point in the history
…ze_backward_kernel (#139)
  • Loading branch information
jb-ye authored Mar 4, 2024
1 parent c45cbdc commit 94cbd12
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 17 deletions.
6 changes: 3 additions & 3 deletions gsplat/cuda/csrc/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ __global__ void nd_rasterize_backward_kernel(
v_alpha += T_final * ra * v_out_alpha;
const float v_sigma = -opac * vis * v_alpha;
v_conic_local = {0.5f * v_sigma * delta.x * delta.x,
0.5f * v_sigma * delta.x * delta.y,
v_sigma * delta.x * delta.y,
0.5f * v_sigma * delta.y * delta.y};
v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y),
v_sigma * (conic.y * delta.x + conic.z * delta.y)};
Expand Down Expand Up @@ -286,8 +286,8 @@ __global__ void rasterize_backward_kernel(

const float v_sigma = -opac * vis * v_alpha;
v_conic_local = {0.5f * v_sigma * delta.x * delta.x,
0.5f * v_sigma * delta.x * delta.y,
0.5f * v_sigma * delta.y * delta.y};
v_sigma * delta.x * delta.y,
0.5f * v_sigma * delta.y * delta.y};
v_xy_local = {v_sigma * (conic.x * delta.x + conic.y * delta.y),
v_sigma * (conic.y * delta.x + conic.z * delta.y)};
v_opacity_local = vis * v_alpha;
Expand Down
16 changes: 8 additions & 8 deletions gsplat/cuda/csrc/helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,19 @@ inline __device__ float3 project_pix_vjp(
const float *mat, const float3 p, const dim3 img_size, const float2 v_xy
) {
// ROW MAJOR mat
float4 p_hom = transform_4x4(mat, p);
float rw = 1.f / (p_hom.w + 1e-6f);
float4 t = transform_4x4(mat, p);
float rw = 1.f / (t.w + 1e-6f);

float3 v_ndc = {0.5f * img_size.x * v_xy.x, 0.5f * img_size.y * v_xy.y};
float4 v_proj = {
v_ndc.x * rw, v_ndc.y * rw, 0., -(v_ndc.x + v_ndc.y) * rw * rw
float4 v_t = {
v_ndc.x * rw, v_ndc.y * rw, 0., -(v_ndc.x * t.x + v_ndc.y * t.y) * rw * rw
};
// df / d_world = df / d_cam * d_cam / d_world
// = v_proj * P[:3, :3]
// = v_t * mat[:3, :4]
return {
mat[0] * v_proj.x + mat[4] * v_proj.y + mat[8] * v_proj.z,
mat[1] * v_proj.x + mat[5] * v_proj.y + mat[9] * v_proj.z,
mat[2] * v_proj.x + mat[6] * v_proj.y + mat[10] * v_proj.z
mat[0] * v_t.x + mat[4] * v_t.y + mat[8] * v_t.z + mat[12] * v_t.w,
mat[1] * v_t.x + mat[5] * v_t.y + mat[9] * v_t.z + mat[13] * v_t.w,
mat[2] * v_t.x + mat[6] * v_t.y + mat[10] * v_t.z + mat[14] * v_t.w,
};
}

Expand Down
2 changes: 1 addition & 1 deletion gsplat/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.7"
__version__ = "0.1.8"
9 changes: 4 additions & 5 deletions tests/test_project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_project_gaussians_forward():
],
device=device,
)
viewmat[:3, :3] = _torch_impl.quat_to_rotmat(torch.randn(4))
projmat = projection_matrix(fx, fy, W, H)
fullmat = projmat @ viewmat
BLOCK_SIZE = 16
Expand Down Expand Up @@ -149,8 +150,8 @@ def test_project_gaussians_backward():
],
device=device,
)
viewmat[:3, :3] = _torch_impl.quat_to_rotmat(torch.randn(4))
projmat = projection_matrix(fx, fy, W, H)
# projmat = torch.eye(4, device=device)
fullmat = projmat @ viewmat

BLOCK_SIZE = 16
Expand Down Expand Up @@ -181,9 +182,7 @@ def test_project_gaussians_backward():
# Test backward pass

v_xys = torch.randn_like(xys)
# v_depths = torch.randn_like(depths)
v_depths = torch.zeros_like(depths)
# scale gradients by pixels to account for finite difference
v_depths = torch.randn_like(depths)
v_conics = torch.randn_like(conics)
v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat = _C.project_gaussians_backward(
num_points,
Expand Down Expand Up @@ -275,7 +274,7 @@ def compute_depth_partial(mean3d):
rtol = 1e-5
check_close(v_cov2d, _v_cov2d, atol=atol, rtol=rtol)
check_close(v_cov3d, _v_cov3d, atol=atol, rtol=rtol)
check_close(v_mean3d[:, :2], _v_mean3d[:, :2], atol=atol, rtol=rtol)
check_close(v_mean3d[:, :], _v_mean3d[:, :], atol=atol, rtol=rtol)
check_close(v_scale, _v_scale, atol=atol, rtol=rtol)
check_close(v_quat, _v_quat, atol=atol, rtol=rtol)
print("passed project_gaussians_backward test")
Expand Down

0 comments on commit 94cbd12

Please sign in to comment.