diff --git a/gsplat/_torch_impl.py b/gsplat/_torch_impl.py index fe6a38be8..db1e01029 100644 --- a/gsplat/_torch_impl.py +++ b/gsplat/_torch_impl.py @@ -142,7 +142,7 @@ def eval_sh_bases_fast(basis_dim: int, dirs: torch.Tensor): result[..., 0] = 0.2820947917738781 if basis_dim <= 1: - return + return result x, y, z = dirs.unbind(-1) @@ -152,7 +152,7 @@ def eval_sh_bases_fast(basis_dim: int, dirs: torch.Tensor): result[..., 1] = fTmpA * y if basis_dim <= 4: - return + return result z2 = z * z fTmpB = -1.092548430592079 * z @@ -166,7 +166,7 @@ def eval_sh_bases_fast(basis_dim: int, dirs: torch.Tensor): result[..., 4] = fTmpA * fS1 if basis_dim <= 9: - return + return result fTmpC = -2.285228997322329 * z2 + 0.4570457994644658 fTmpB = 1.445305721320277 * z @@ -182,7 +182,7 @@ def eval_sh_bases_fast(basis_dim: int, dirs: torch.Tensor): result[..., 9] = fTmpA * fS2 if basis_dim <= 16: - return + return result fTmpD = z * (-4.683325804901025 * z2 + 2.007139630671868) fTmpC = 3.31161143515146 * z2 - 0.47308734787878