diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index e73cdd86a..50f9e81c3 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -83,13 +83,17 @@ def _persp_proj( fx = Ks[..., 0, 0, None] # [C, 1] fy = Ks[..., 1, 1, None] # [C, 1] + cx = Ks[..., 0, 2, None] # [C, 1] + cy = Ks[..., 1, 2, None] # [C, 1] tan_fovx = 0.5 * width / fx # [C, 1] tan_fovy = 0.5 * height / fy # [C, 1] - lim_x = 1.3 * tan_fovx - lim_y = 1.3 * tan_fovy - tx = tz * torch.clamp(tx / tz, min=-lim_x, max=lim_x) - ty = tz * torch.clamp(ty / tz, min=-lim_y, max=lim_y) + lim_x_pos = (width - cx) / fx + 0.3 * tan_fovx + lim_x_neg = cx / fx + 0.3 * tan_fovx + lim_y_pos = (height - cy) / fy + 0.3 * tan_fovy + lim_y_neg = cy / fy + 0.3 * tan_fovy + tx = tz * torch.clamp(tx / tz, min=-lim_x_neg, max=lim_x_pos) + ty = tz * torch.clamp(ty / tz, min=-lim_y_neg, max=lim_y_pos) O = torch.zeros((C, N), device=means.device, dtype=means.dtype) J = torch.stack( diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 7e23a0f8e..7519cc5c3 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -153,13 +153,15 @@ inline __device__ void persp_proj( T tan_fovx = 0.5f * width / fx; T tan_fovy = 0.5f * height / fy; - T lim_x = 1.3f * tan_fovx; - T lim_y = 1.3f * tan_fovy; + T lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx; + T lim_x_neg = cx / fx + 0.3f * tan_fovx; + T lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy; + T lim_y_neg = cy / fy + 0.3f * tan_fovy; T rz = 1.f / z; T rz2 = rz * rz; - T tx = z * min(lim_x, max(-lim_x, x * rz)); - T ty = z * min(lim_y, max(-lim_y, y * rz)); + T tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz)); + T ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz)); // mat3x2 is 3 columns x 2 rows. mat3x2 J = mat3x2(fx * rz, 0.f, // 1st column @@ -183,13 +185,15 @@ inline __device__ void persp_proj_vjp( T tan_fovx = 0.5f * width / fx; T tan_fovy = 0.5f * height / fy; - T lim_x = 1.3f * tan_fovx; - T lim_y = 1.3f * tan_fovy; + T lim_x_pos = (width - cx) / fx + 0.3f * tan_fovx; + T lim_x_neg = cx / fx + 0.3f * tan_fovx; + T lim_y_pos = (height - cy) / fy + 0.3f * tan_fovy; + T lim_y_neg = cy / fy + 0.3f * tan_fovy; T rz = 1.f / z; T rz2 = rz * rz; - T tx = z * min(lim_x, max(-lim_x, x * rz)); - T ty = z * min(lim_y, max(-lim_y, y * rz)); + T tx = z * min(lim_x_pos, max(-lim_x_neg, x * rz)); + T ty = z * min(lim_y_pos, max(-lim_y_neg, y * rz)); // mat3x2 is 3 columns x 2 rows. mat3x2 J = mat3x2(fx * rz, 0.f, // 1st column @@ -217,12 +221,12 @@ inline __device__ void persp_proj_vjp( v_cov2d * J * glm::transpose(cov3d) + glm::transpose(v_cov2d) * J * cov3d; // fov clipping - if (x * rz <= lim_x && x * rz >= -lim_x) { + if (x * rz <= lim_x_pos && x * rz >= -lim_x_neg) { v_mean3d.x += -fx * rz2 * v_J[2][0]; } else { v_mean3d.z += -fx * rz3 * v_J[2][0] * tx; } - if (y * rz <= lim_y && y * rz >= -lim_y) { + if (y * rz <= lim_y_pos && y * rz >= -lim_y_neg) { v_mean3d.y += -fy * rz2 * v_J[2][1]; } else { v_mean3d.z += -fy * rz3 * v_J[2][1] * ty; @@ -341,4 +345,4 @@ inline __device__ void add_blur_vjp(const T eps2d, const mat2 conic_blur, v_sqr_comp * (one_minus_sqr_comp * conic_blur[1][1] - eps2d * det_conic_blur); } -#endif // GSPLAT_CUDA_UTILS_H \ No newline at end of file +#endif // GSPLAT_CUDA_UTILS_H