diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 054e73721..702b63bff 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -111,7 +111,7 @@ def rasterization( Args: means: The 3D centers of the Gaussians. [N, 3] - quats: The quaternions of the Gaussians. It's not required to be normalized. [N, 4] + quats: The quaternions of the Gaussians (wxyz convension). It's not required to be normalized. [N, 4] scales: The scales of the Gaussians. [N, 3] opacities: The opacities of the Gaussians. [N] colors: The colors of the Gaussians. [(C,) N, D] or [(C,) N, K, 3] for SH coefficients. diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index cbb3705c3..537bc3c9b 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -323,5 +323,5 @@ def op_sigmoid(x, k=100, x0=0.995): * (op_sigmoid(1 - opacities)).unsqueeze(-1) * scaler ) - noise = torch.bmm(covars, noise.unsqueeze(-1)).squeeze(-1) + noise = torch.einsum("bij,bj->bi", covars, noise) params["means"].add_(noise) diff --git a/gsplat/utils.py b/gsplat/utils.py index 842917400..8a967e352 100644 --- a/gsplat/utils.py +++ b/gsplat/utils.py @@ -6,7 +6,7 @@ def normalized_quat_to_rotmat(quat: Tensor) -> Tensor: """Convert normalized quaternion to rotation matrix. Args: - quat: Normalized quaternion (..., 4) + quat: Normalized quaternion in wxyz convension. (..., 4) Returns: Rotation matrix (..., 3, 3)