From 2d69aaf704cc7208b0d7a90794d286c49d758e70 Mon Sep 17 00:00:00 2001 From: Justin Kerr Date: Wed, 27 Mar 2024 09:52:40 -0700 Subject: [PATCH] remove projmat from API (#149) * remove 'projmat' arg --- examples/simple_trainer.py | 1 - gsplat/_torch_impl.py | 18 ++++++++++-------- gsplat/project_gaussians.py | 10 +++------- tests/test_project_gaussians.py | 1 - 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 0f2e3d5e0..8e9cb4538 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -104,7 +104,6 @@ def train( 1, self.quats, self.viewmat, - self.viewmat, self.focal, self.focal, self.W / 2, diff --git a/gsplat/_torch_impl.py b/gsplat/_torch_impl.py index 9893603b3..5babc725c 100644 --- a/gsplat/_torch_impl.py +++ b/gsplat/_torch_impl.py @@ -118,15 +118,15 @@ def normalized_quat_to_rotmat(quat: Tensor) -> Tensor: w, x, y, z = torch.unbind(quat, dim=-1) mat = torch.stack( [ - 1 - 2 * (y**2 + z**2), + 1 - 2 * (y ** 2 + z ** 2), 2 * (x * y - w * z), 2 * (x * z + w * y), 2 * (x * y + w * z), - 1 - 2 * (x**2 + z**2), + 1 - 2 * (x ** 2 + z ** 2), 2 * (y * z - w * x), 2 * (x * z - w * y), 2 * (y * z + w * x), - 1 - 2 * (x**2 + y**2), + 1 - 2 * (x ** 2 + y ** 2), ], dim=-1, ) @@ -165,7 +165,7 @@ def project_cov3d_ewa( t = torch.einsum("...ij,...j->...i", W, mean3d) + p # (..., 3) rz = 1.0 / t[..., 2] # (...,) - rz2 = rz**2 # (...,) + rz2 = rz ** 2 # (...,) lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device) lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device) @@ -220,8 +220,8 @@ def compute_cov2d_bounds(cov2d_mat: Tensor): dim=-1, ) # (..., 3) b = (cov2d[..., 0, 0] + cov2d[..., 1, 1]) / 2 # (...,) - v1 = b + torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,) - v2 = b - torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,) + v1 = b + torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,) + v2 = b - torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,) radius = torch.ceil(3.0 * torch.sqrt(torch.max(v1, v2))) # (...,) radius_all = torch.zeros(*cov2d_mat.shape[:-2], device=cov2d_mat.device) conic_all = torch.zeros(*cov2d_mat.shape[:-2], 3, device=cov2d_mat.device) @@ -229,15 +229,17 @@ def compute_cov2d_bounds(cov2d_mat: Tensor): conic_all[valid] = conic return conic_all, radius_all, valid + def project_pix(fxfy, p_view, center, eps=1e-6): fx, fy = fxfy cx, cy = center rw = 1.0 / (p_view[..., 2] + 1e-6) - p_proj = ( p_view[..., 0] * rw, p_view[..., 1] * rw ) - u, v = ( p_proj[0] * fx + cx, p_proj[1] * fy + cy ) + p_proj = (p_view[..., 0] * rw, p_view[..., 1] * rw) + u, v = (p_proj[0] * fx + cx, p_proj[1] * fy + cy) return torch.stack([u, v], dim=-1) + def clip_near_plane(p, viewmat, clip_thresh=0.01): R = viewmat[:3, :3] T = viewmat[:3, 3] diff --git a/gsplat/project_gaussians.py b/gsplat/project_gaussians.py index fbd7566e8..b41192c5b 100644 --- a/gsplat/project_gaussians.py +++ b/gsplat/project_gaussians.py @@ -16,7 +16,6 @@ def project_gaussians( glob_scale: float, quats: Float[Tensor, "*batch 4"], viewmat: Float[Tensor, "4 4"], - projmat: Optional[Float[Tensor, "4 4"]], fx: float, fy: float, cx: float, @@ -37,7 +36,6 @@ def project_gaussians( glob_scale (float): A global scaling factor applied to the scene. quats (Tensor): rotations in quaternion [w,x,y,z] format. viewmat (Tensor): view matrix for rendering. - projmat (Tensor): DEPRECATED and ignored. Set to None fx (float): focal length x. fy (float): focal length y. cx (float): principal point x. @@ -65,7 +63,6 @@ def project_gaussians( glob_scale, quats.contiguous(), viewmat.contiguous(), - None, fx, fy, cx, @@ -88,7 +85,6 @@ def forward( glob_scale: float, quats: Float[Tensor, "*batch 4"], viewmat: Float[Tensor, "4 4"], - projmat: Optional[Float[Tensor, "4 4"]], fx: float, fy: float, cx: float, @@ -227,7 +223,9 @@ def backward( # gradent w.r.t. view matrix rotation for j in range(3): for l in range(3): - v_viewmat[..., j, l] = torch.dot(v_mean3d_cam[..., j], means3d[..., l]) + v_viewmat[..., j, l] = torch.dot( + v_mean3d_cam[..., j], means3d[..., l] + ) else: v_viewmat = None @@ -243,8 +241,6 @@ def backward( v_quat, # viewmat: Float[Tensor, "4 4"], v_viewmat, - # projmat: Float[Tensor, "4 4"], - None, # fx: float, None, # fy: float, diff --git a/tests/test_project_gaussians.py b/tests/test_project_gaussians.py index 3f0d30cbd..c08af4292 100644 --- a/tests/test_project_gaussians.py +++ b/tests/test_project_gaussians.py @@ -77,7 +77,6 @@ def test_project_gaussians_forward(): glob_scale, quats, viewmat, - None, # deprecated projmat/fullmat fx, fy, cx,