Skip to content

Commit

Permalink
add backprop grad for opacity compensation (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored Mar 5, 2024
1 parent 94cbd12 commit fecca4f
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 5 deletions.
14 changes: 13 additions & 1 deletion gsplat/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,19 @@ def project_cov3d_ewa(
cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3
det_blur = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1]
compensation = torch.sqrt(torch.clamp(det_orig / det_blur, min=0))
return cov2d[..., :2, :2], compensation.detach()
return cov2d[..., :2, :2], compensation


def compute_compensation(cov2d_mat: Tensor):
"""
params: cov2d matrix (*, 2, 2)
returns: compensation factor as calculated in project_cov3d_ewa
"""
det_denom = cov2d_mat[..., 0, 0] * cov2d_mat[..., 1, 1] - cov2d_mat[..., 0, 1] ** 2
det_nomin = (cov2d_mat[..., 0, 0] - 0.3) * (cov2d_mat[..., 1, 1] - 0.3) - cov2d_mat[
..., 0, 1
] ** 2
return torch.sqrt(torch.clamp(det_nomin / det_denom, min=0))


def compute_cov2d_bounds(cov2d_mat: Tensor):
Expand Down
3 changes: 3 additions & 0 deletions gsplat/cuda/csrc/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,11 @@ __global__ void project_gaussians_backward_kernel(
const float* __restrict__ cov3d,
const int* __restrict__ radii,
const float3* __restrict__ conics,
const float* __restrict__ compensation,
const float2* __restrict__ v_xy,
const float* __restrict__ v_depth,
const float3* __restrict__ v_conic,
const float* __restrict__ v_compensation,
float3* __restrict__ v_cov2d,
float* __restrict__ v_cov3d,
float3* __restrict__ v_mean3d,
Expand Down Expand Up @@ -362,6 +364,7 @@ __global__ void project_gaussians_backward_kernel(

// get v_cov2d
cov2d_to_conic_vjp(conics[idx], v_conic[idx], v_cov2d[idx]);
cov2d_to_compensation_vjp(compensation[idx], conics[idx], v_compensation[idx], v_cov2d[idx]);
// get v_cov3d (and v_mean3d contribution)
project_cov3d_ewa_vjp(
p_world,
Expand Down
2 changes: 2 additions & 0 deletions gsplat/cuda/csrc/backward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ __global__ void project_gaussians_backward_kernel(
const float* __restrict__ cov3d,
const int* __restrict__ radii,
const float3* __restrict__ conics,
const float* __restrict__ compensation,
const float2* __restrict__ v_xy,
const float* __restrict__ v_depth,
const float3* __restrict__ v_conic,
const float* __restrict__ v_compensation,
float3* __restrict__ v_cov2d,
float* __restrict__ v_cov3d,
float3* __restrict__ v_mean3d,
Expand Down
8 changes: 6 additions & 2 deletions gsplat/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,12 @@ project_gaussians_backward_tensor(
torch::Tensor &cov3d,
torch::Tensor &radii,
torch::Tensor &conics,
torch::Tensor &compensation,
torch::Tensor &v_xy,
torch::Tensor &v_depth,
torch::Tensor &v_conic
) {
torch::Tensor &v_conic,
torch::Tensor &v_compensation
){
DEVICE_GUARD(means3d);
dim3 img_size_dim3;
img_size_dim3.x = img_width;
Expand Down Expand Up @@ -265,9 +267,11 @@ project_gaussians_backward_tensor(
cov3d.contiguous().data_ptr<float>(),
radii.contiguous().data_ptr<int32_t>(),
(float3 *)conics.contiguous().data_ptr<float>(),
(float *)compensation.contiguous().data_ptr<float>(),
(float2 *)v_xy.contiguous().data_ptr<float>(),
v_depth.contiguous().data_ptr<float>(),
(float3 *)v_conic.contiguous().data_ptr<float>(),
(float *)v_compensation.contiguous().data_ptr<float>(),
// Outputs.
(float3 *)v_cov2d.contiguous().data_ptr<float>(),
v_cov3d.contiguous().data_ptr<float>(),
Expand Down
4 changes: 3 additions & 1 deletion gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,11 @@ project_gaussians_backward_tensor(
torch::Tensor &cov3d,
torch::Tensor &radii,
torch::Tensor &conics,
torch::Tensor &compensation,
torch::Tensor &v_xy,
torch::Tensor &v_depth,
torch::Tensor &v_conic
torch::Tensor &v_conic,
torch::Tensor &v_compensation
);


Expand Down
15 changes: 15 additions & 0 deletions gsplat/cuda/csrc/helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,21 @@ inline __device__ void cov2d_to_conic_vjp(
v_cov2d.z = v_Sigma[1][1];
}

inline __device__ void cov2d_to_compensation_vjp(
const float compensation, const float3 &conic, const float v_compensation, float3 &v_cov2d
) {
// comp = sqrt(det(cov2d - 0.3 I) / det(cov2d))
// conic = inverse(cov2d)
// df / d_cov2d = df / d comp * 0.5 / comp * [ d comp^2 / d cov2d ]
// d comp^2 / d cov2d = (1 - comp^2) * conic - 0.3 I * det(conic)
float inv_det = conic.x * conic.z - conic.y * conic.y;
float one_minus_sqr_comp = 1 - compensation * compensation;
float v_sqr_comp = v_compensation * 0.5 / (compensation + 1e-6);
v_cov2d.x += v_sqr_comp * (one_minus_sqr_comp * conic.x - 0.3 * inv_det);
v_cov2d.y += 2 * v_sqr_comp * (one_minus_sqr_comp * conic.y);
v_cov2d.z += v_sqr_comp * (one_minus_sqr_comp * conic.z - 0.3 * inv_det);
}

// helper for applying R * p + T, expect mat to be ROW MAJOR
inline __device__ float3 transform_4x3(const float *mat, const float3 p) {
float3 out = {
Expand Down
4 changes: 4 additions & 0 deletions gsplat/project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def forward(
cov3d,
radii,
conics,
compensation,
)

return (xys, depths, radii, conics, compensation, num_tiles_hit, cov3d)
Expand All @@ -171,6 +172,7 @@ def backward(
cov3d,
radii,
conics,
compensation,
) = ctx.saved_tensors

(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat) = _C.project_gaussians_backward(
Expand All @@ -190,9 +192,11 @@ def backward(
cov3d,
radii,
conics,
compensation,
v_xys,
v_depths,
v_conics,
v_compensation,
)

# Return a gradient for each input.
Expand Down
18 changes: 17 additions & 1 deletion tests/test_project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def test_project_gaussians_backward():
depths,
radii,
conics,
_,
compensation,
_,
masks,
) = _torch_impl.project_gaussians_forward(
Expand All @@ -184,6 +184,8 @@ def test_project_gaussians_backward():
v_xys = torch.randn_like(xys)
v_depths = torch.randn_like(depths)
v_conics = torch.randn_like(conics)
# compensation gradients
v_compensation = torch.randn_like(compensation)
v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat = _C.project_gaussians_backward(
num_points,
means3d,
Expand All @@ -201,9 +203,11 @@ def test_project_gaussians_backward():
cov3d,
radii,
conics,
compensation,
v_xys,
v_depths,
v_conics,
v_compensation,
)

def scale_rot_to_cov3d_partial(scale, quat):
Expand Down Expand Up @@ -243,6 +247,16 @@ def compute_cov2d_bounds_partial(cov2d):
conic, _, _ = _torch_impl.compute_cov2d_bounds(cov2d_mat)
return conic

def compute_compensation_partial(cov2d):
"""
cov2d (upper tri) (*, 3) -> compensation (*, 1)
"""
cov2d_mat = torch.zeros(*cov2d.shape[:-1], 2, 2, device=device)
i, j = torch.triu_indices(2, 2)
cov2d_mat[..., i, j] = cov2d
cov2d_mat[..., 1, 0] = cov2d[..., 1]
return _torch_impl.compute_compensation(cov2d_mat)

def project_pix_partial(mean3d):
"""
mean3d (*, 3) -> xy (*, 2)
Expand All @@ -260,10 +274,12 @@ def compute_depth_partial(mean3d):
_, vjp_scale_rot_to_cov3d = vjp(scale_rot_to_cov3d_partial, scales, quats) # type: ignore
_, vjp_project_cov3d_ewa = vjp(project_cov3d_ewa_partial, means3d, cov3d) # type: ignore
_, vjp_compute_cov2d_bounds = vjp(compute_cov2d_bounds_partial, cov2d) # type: ignore
_, vjp_compute_compensation = vjp(compute_compensation_partial, cov2d) # type: ignore
_, vjp_project_pix = vjp(project_pix_partial, means3d) # type: ignore
_, vjp_compute_depth = vjp(compute_depth_partial, means3d) # type: ignore

_v_cov2d = vjp_compute_cov2d_bounds(v_conics)[0]
_v_cov2d = _v_cov2d + vjp_compute_compensation(v_compensation)[0]
_v_mean3d_cov2d, _v_cov3d = vjp_project_cov3d_ewa(_v_cov2d)
_v_mean3d_xy = vjp_project_pix(v_xys)[0]
_v_mean3d_depth = vjp_compute_depth(v_depths)[0]
Expand Down

0 comments on commit fecca4f

Please sign in to comment.