From 4609d3bfa3fddb9f2c50af5250a677db67cbf49e Mon Sep 17 00:00:00 2001 From: "J.Y" <132313008+jb-ye@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:58:49 -0400 Subject: [PATCH] Fix missing +0.5 in calculating uv coordinate (#151) --- gsplat/_torch_impl.py | 12 ++++++------ gsplat/cuda/csrc/backward.cu | 8 ++++---- gsplat/cuda/csrc/forward.cu | 8 ++++---- gsplat/version.py | 2 +- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/gsplat/_torch_impl.py b/gsplat/_torch_impl.py index 5babc725c..6f4f936f6 100644 --- a/gsplat/_torch_impl.py +++ b/gsplat/_torch_impl.py @@ -118,15 +118,15 @@ def normalized_quat_to_rotmat(quat: Tensor) -> Tensor: w, x, y, z = torch.unbind(quat, dim=-1) mat = torch.stack( [ - 1 - 2 * (y ** 2 + z ** 2), + 1 - 2 * (y**2 + z**2), 2 * (x * y - w * z), 2 * (x * z + w * y), 2 * (x * y + w * z), - 1 - 2 * (x ** 2 + z ** 2), + 1 - 2 * (x**2 + z**2), 2 * (y * z - w * x), 2 * (x * z - w * y), 2 * (y * z + w * x), - 1 - 2 * (x ** 2 + y ** 2), + 1 - 2 * (x**2 + y**2), ], dim=-1, ) @@ -165,7 +165,7 @@ def project_cov3d_ewa( t = torch.einsum("...ij,...j->...i", W, mean3d) + p # (..., 3) rz = 1.0 / t[..., 2] # (...,) - rz2 = rz ** 2 # (...,) + rz2 = rz**2 # (...,) lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device) lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device) @@ -220,8 +220,8 @@ def compute_cov2d_bounds(cov2d_mat: Tensor): dim=-1, ) # (..., 3) b = (cov2d[..., 0, 0] + cov2d[..., 1, 1]) / 2 # (...,) - v1 = b + torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,) - v2 = b - torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,) + v1 = b + torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,) + v2 = b - torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,) radius = torch.ceil(3.0 * torch.sqrt(torch.max(v1, v2))) # (...,) radius_all = torch.zeros(*cov2d_mat.shape[:-2], device=cov2d_mat.device) conic_all = torch.zeros(*cov2d_mat.shape[:-2], 3, device=cov2d_mat.device) diff --git a/gsplat/cuda/csrc/backward.cu b/gsplat/cuda/csrc/backward.cu index f02853659..8b06d0be1 100644 --- a/gsplat/cuda/csrc/backward.cu +++ b/gsplat/cuda/csrc/backward.cu @@ -44,8 +44,8 @@ __global__ void nd_rasterize_backward_kernel( int32_t tile_id = blockIdx.y * tile_bounds.x + blockIdx.x; unsigned i = blockIdx.y * blockDim.y + threadIdx.y; unsigned j = blockIdx.x * blockDim.x + threadIdx.x; - float px = (float)j; - float py = (float)i; + float px = (float)j + 0.5; + float py = (float)i + 0.5; const int32_t pix_id = min(i * img_size.x + j, img_size.x * img_size.y - 1); // keep not rasterizing threads around for reading data @@ -159,8 +159,8 @@ __global__ void rasterize_backward_kernel( unsigned j = block.group_index().x * block.group_dim().x + block.thread_index().x; - const float px = (float)j; - const float py = (float)i; + const float px = (float)j + 0.5; + const float py = (float)i + 0.5; // clamp this value to the last pixel const int32_t pix_id = min(i * img_size.x + j, img_size.x * img_size.y - 1); diff --git a/gsplat/cuda/csrc/forward.cu b/gsplat/cuda/csrc/forward.cu index 8ef342dc7..4b2c2edf0 100644 --- a/gsplat/cuda/csrc/forward.cu +++ b/gsplat/cuda/csrc/forward.cu @@ -195,8 +195,8 @@ __global__ void nd_rasterize_forward( unsigned j = block.group_index().x * block.group_dim().x + block.thread_index().x; - float px = (float)j; - float py = (float)i; + float px = (float)j + 0.5; + float py = (float)i + 0.5; int32_t pix_id = i * img_size.x + j; // keep not rasterizing threads around for reading data @@ -318,8 +318,8 @@ __global__ void rasterize_forward( unsigned j = block.group_index().x * block.group_dim().x + block.thread_index().x; - float px = (float)j; - float py = (float)i; + float px = (float)j + 0.5; + float py = (float)i + 0.5; int32_t pix_id = i * img_size.x + j; // return if out of bounds diff --git a/gsplat/version.py b/gsplat/version.py index c11f861af..569b1212f 100644 --- a/gsplat/version.py +++ b/gsplat/version.py @@ -1 +1 @@ -__version__ = "0.1.9" +__version__ = "0.1.10"