diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index 1f9c2405a..2585e36b3 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -137,7 +137,9 @@ def _ortho_proj( J = torch.stack([fx, O, O, O, fy, O], dim=-1).reshape(C, 1, 2, 3).repeat(1, N, 1, 1) cov2d = torch.einsum("...ij,...jk,...kl->...il", J, covars, J.transpose(-1, -2)) - means2d = means[..., :2] * Ks[:, None, [0, 1], [0, 1]] + Ks[:, None, [0, 1], [2, 2]] # [C, N, 2] + means2d = ( + means[..., :2] * Ks[:, None, [0, 1], [0, 1]] + Ks[:, None, [0, 1], [2, 2]] + ) # [C, N, 2] return means2d, cov2d # [C, N, 2], [C, N, 2, 2]