From b08e0833a99163f7f20fcc636ae2f4608ac68b07 Mon Sep 17 00:00:00 2001 From: jb-ye Date: Thu, 1 Aug 2024 19:43:22 +0000 Subject: [PATCH] Fix projection for images with non-centered camera (e.g. crops) --- gsplat/_torch_impl.py | 16 ++++++++-------- gsplat/cuda/csrc/backward.cu | 20 ++++++++++++++++++++ gsplat/cuda/csrc/backward.cuh | 4 ++++ gsplat/cuda/csrc/forward.cu | 19 +++++++++++-------- gsplat/cuda/csrc/forward.cuh | 6 ++++-- gsplat/version.py | 2 +- tests/test_project_gaussians.py | 4 +++- 7 files changed, 51 insertions(+), 20 deletions(-) diff --git a/gsplat/_torch_impl.py b/gsplat/_torch_impl.py index 278c542ed..68c01d38d 100644 --- a/gsplat/_torch_impl.py +++ b/gsplat/_torch_impl.py @@ -245,8 +245,8 @@ def project_cov3d_ewa( viewmat: Tensor, fx: float, fy: float, - tan_fovx: float, - tan_fovy: float, + lim_x: Tuple[float, float], + lim_y: Tuple[float, float], is_valid: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: assert mean3d.shape[-1] == 3, mean3d.shape @@ -263,10 +263,8 @@ def project_cov3d_ewa( rz = 1.0 / t[..., 2] # (...,) rz2 = rz**2 # (...,) - lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device) - lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device) - x_clamp = t[..., 2] * torch.clamp(t[..., 0] * rz, min=-lim_x, max=lim_x) - y_clamp = t[..., 2] * torch.clamp(t[..., 1] * rz, min=-lim_y, max=lim_y) + x_clamp = t[..., 2] * torch.clamp(t[..., 0] * rz, min=-lim_x[1], max=lim_x[0]) + y_clamp = t[..., 2] * torch.clamp(t[..., 1] * rz, min=-lim_y[1], max=lim_y[0]) t = torch.stack([x_clamp, y_clamp, t[..., 2]], dim=-1) O = torch.zeros_like(rz) @@ -352,7 +350,7 @@ def project_pix(fxfy, p_view, center, eps=1e-6): return torch.stack([u, v], dim=-1) -def clip_near_plane(p, viewmat, clip_thresh=0.01): +def clip_near_plane(p, viewmat, clip_thresh=0.01) -> Tuple[Tensor, Tensor]: R = viewmat[:3, :3] T = viewmat[:3, 3] p_view = torch.einsum("ij,nj->ni", R, p) + T[None] @@ -404,10 +402,12 @@ def project_gaussians_forward( fx, fy, cx, cy = intrins tan_fovx = 0.5 * img_size[0] / fx tan_fovy = 0.5 * img_size[1] / fy + lim_x = ((img_size[0] - cx) / fx + 0.3 * tan_fovx, cx / fx + 0.3 * tan_fovx) + lim_y = ((img_size[1] - cy) / fy + 0.3 * tan_fovy, cy / fy + 0.3 * tan_fovy) p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh) cov3d = scale_rot_to_cov3d(scales, glob_scale, quats) cov2d, compensation = project_cov3d_ewa( - means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy, ~is_close + means3d, cov3d, viewmat, fx, fy, lim_x, lim_y, ~is_close ) conic, radius, det_valid = compute_cov2d_bounds(cov2d, ~is_close) xys = project_pix((fx, fy), p_view, (cx, cy)) diff --git a/gsplat/cuda/csrc/backward.cu b/gsplat/cuda/csrc/backward.cu index 908724b35..4cbd5c30c 100644 --- a/gsplat/cuda/csrc/backward.cu +++ b/gsplat/cuda/csrc/backward.cu @@ -357,6 +357,10 @@ __global__ void project_gaussians_backward_kernel( float3 p_world = means3d[idx]; float fx = intrins.x; float fy = intrins.y; + float cx = intrins.z; + float cy = intrins.w; + float tan_fovx = 0.5 * img_size.x / fx; + float tan_fovy = 0.5 * img_size.y / fy; float3 p_view = transform_4x3(viewmat, p_world); // get v_mean3d from v_xy v_mean3d[idx] = transform_4x3_rot_only_transposed( @@ -375,12 +379,20 @@ __global__ void project_gaussians_backward_kernel( cov2d_to_conic_vjp(conics[idx], v_conic[idx], v_cov2d[idx]); cov2d_to_compensation_vjp(compensation[idx], conics[idx], v_compensation[idx], v_cov2d[idx]); // get v_cov3d (and v_mean3d contribution) + float lim_x_pos = (img_size.x - cx) / fx + 0.3f * tan_fovx; + float lim_x_neg = cx / fx + 0.3f * tan_fovx; + float lim_y_pos = (img_size.y - cy) / fy + 0.3f * tan_fovy; + float lim_y_neg = cy / fy + 0.3f * tan_fovy; project_cov3d_ewa_vjp( p_world, &(cov3d[6 * idx]), viewmat, fx, fy, + lim_x_pos, + lim_x_neg, + lim_y_pos, + lim_y_neg, v_cov2d[idx], v_mean3d[idx], &(v_cov3d[6 * idx]) @@ -403,6 +415,10 @@ __device__ void project_cov3d_ewa_vjp( const float* __restrict__ viewmat, const float fx, const float fy, + const float lim_x_pos, + const float lim_x_neg, + const float lim_y_pos, + const float lim_y_neg, const float3& __restrict__ v_cov2d, float3& __restrict__ v_mean3d, float* __restrict__ v_cov3d @@ -418,6 +434,10 @@ __device__ void project_cov3d_ewa_vjp( // clang-format on glm::vec3 p = glm::vec3(viewmat[3], viewmat[7], viewmat[11]); glm::vec3 t = W * glm::vec3(mean3d.x, mean3d.y, mean3d.z) + p; + + t.x = t.z * std::min(lim_x_pos, std::max(-lim_x_neg, t.x / t.z)); + t.y = t.z * std::min(lim_y_pos, std::max(-lim_y_neg, t.y / t.z)); + float rz = 1.f / t.z; float rz2 = rz * rz; diff --git a/gsplat/cuda/csrc/backward.cuh b/gsplat/cuda/csrc/backward.cuh index 9e67a8d2a..06788e8ca 100644 --- a/gsplat/cuda/csrc/backward.cuh +++ b/gsplat/cuda/csrc/backward.cuh @@ -79,6 +79,10 @@ __device__ void project_cov3d_ewa_vjp( const float *viewmat, const float fx, const float fy, + const float lim_x_pos, + const float lim_x_neg, + const float lim_y_pos, + const float lim_y_neg, const float3 &v_cov2d, float3 &v_mean3d, float *v_cov3d diff --git a/gsplat/cuda/csrc/forward.cu b/gsplat/cuda/csrc/forward.cu index 4b2c2edf0..e9cfb5912 100644 --- a/gsplat/cuda/csrc/forward.cu +++ b/gsplat/cuda/csrc/forward.cu @@ -63,10 +63,14 @@ __global__ void project_gaussians_forward_kernel( float cy = intrins.w; float tan_fovx = 0.5 * img_size.x / fx; float tan_fovy = 0.5 * img_size.y / fy; + float lim_x_pos = (img_size.x - cx) / fx + 0.3f * tan_fovx; + float lim_x_neg = cx / fx + 0.3f * tan_fovx; + float lim_y_pos = (img_size.y - cy) / fy + 0.3f * tan_fovy; + float lim_y_neg = cy / fy + 0.3f * tan_fovy; float3 cov2d; float comp; project_cov3d_ewa( - p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy, + p_world, cur_cov3d, viewmat, fx, fy, lim_x_pos, lim_x_neg, lim_y_pos, lim_y_neg, cov2d, comp ); // printf("cov2d %d, %.2f %.2f %.2f\n", idx, cov2d.x, cov2d.y, cov2d.z); @@ -425,8 +429,10 @@ __device__ void project_cov3d_ewa( const float* __restrict__ viewmat, const float fx, const float fy, - const float tan_fovx, - const float tan_fovy, + const float lim_x_pos, + const float lim_x_neg, + const float lim_y_pos, + const float lim_y_neg, float3 &cov2d, float &compensation ) { @@ -447,11 +453,8 @@ __device__ void project_cov3d_ewa( glm::vec3 p = glm::vec3(viewmat[3], viewmat[7], viewmat[11]); glm::vec3 t = W * glm::vec3(mean3d.x, mean3d.y, mean3d.z) + p; - // clip so that the covariance - float lim_x = 1.3f * tan_fovx; - float lim_y = 1.3f * tan_fovy; - t.x = t.z * std::min(lim_x, std::max(-lim_x, t.x / t.z)); - t.y = t.z * std::min(lim_y, std::max(-lim_y, t.y / t.z)); + t.x = t.z * std::min(lim_x_pos, std::max(-lim_x_neg, t.x / t.z)); + t.y = t.z * std::min(lim_y_pos, std::max(-lim_y_neg, t.y / t.z)); float rz = 1.f / t.z; float rz2 = rz * rz; diff --git a/gsplat/cuda/csrc/forward.cuh b/gsplat/cuda/csrc/forward.cuh index 0d17e3ba9..02325ee91 100644 --- a/gsplat/cuda/csrc/forward.cuh +++ b/gsplat/cuda/csrc/forward.cuh @@ -64,8 +64,10 @@ __device__ void project_cov3d_ewa( const float *viewmat, const float fx, const float fy, - const float tan_fovx, - const float tan_fovy, + const float lim_x_pos, + const float lim_x_neg, + const float lim_y_pos, + const float lim_y_neg, float3 &cov2d, float &comp ); diff --git a/gsplat/version.py b/gsplat/version.py index 74acd0efb..3cb7d95ef 100644 --- a/gsplat/version.py +++ b/gsplat/version.py @@ -1 +1 @@ -__version__ = "0.1.12" +__version__ = "0.1.13" diff --git a/tests/test_project_gaussians.py b/tests/test_project_gaussians.py index c08af4292..0de02c3e3 100644 --- a/tests/test_project_gaussians.py +++ b/tests/test_project_gaussians.py @@ -217,13 +217,15 @@ def project_cov3d_ewa_partial(mean3d, cov3d): """ tan_fovx = 0.5 * W / fx tan_fovy = 0.5 * H / fy + lim_x = ((W - cx) / fx + 0.3 * tan_fovx, cx / fx + 0.3 * tan_fovx) + lim_y = ((H - cy) / fy + 0.3 * tan_fovy, cy / fy + 0.3 * tan_fovy) cov3d_mat = torch.zeros(*cov3d.shape[:-1], 3, 3, device=device) i, j = torch.triu_indices(3, 3) cov3d_mat[..., i, j] = cov3d cov3d_mat[..., [1, 2, 2], [0, 0, 1]] = cov3d[..., [1, 2, 4]] cov2d, _ = _torch_impl.project_cov3d_ewa( - mean3d, cov3d_mat, viewmat, fx, fy, tan_fovx, tan_fovy + mean3d, cov3d_mat, viewmat, fx, fy, lim_x, lim_y ) ii, jj = torch.triu_indices(2, 2) return cov2d[..., ii, jj]