diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 8e9cb4538..617ec189d 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -102,7 +102,7 @@ def train( self.means, self.scales, 1, - self.quats, + self.quats / self.quats.norm(dim=-1, keepdim=True), self.viewmat, self.focal, self.focal, diff --git a/gsplat/cuda/csrc/helpers.cuh b/gsplat/cuda/csrc/helpers.cuh index 0e665c2f2..4a8b3578b 100644 --- a/gsplat/cuda/csrc/helpers.cuh +++ b/gsplat/cuda/csrc/helpers.cuh @@ -148,13 +148,10 @@ inline __device__ float3 project_pix_vjp( inline __device__ glm::mat3 quat_to_rotmat(const float4 quat) { // quat to rotation matrix - float s = rsqrtf( - quat.w * quat.w + quat.x * quat.x + quat.y * quat.y + quat.z * quat.z - ); - float w = quat.x * s; - float x = quat.y * s; - float y = quat.z * s; - float z = quat.w * s; + float w = quat.x; + float x = quat.y; + float y = quat.z; + float z = quat.w; // glm matrices are column-major return glm::mat3( @@ -172,13 +169,10 @@ inline __device__ glm::mat3 quat_to_rotmat(const float4 quat) { inline __device__ float4 quat_to_rotmat_vjp(const float4 quat, const glm::mat3 v_R) { - float s = rsqrtf( - quat.w * quat.w + quat.x * quat.x + quat.y * quat.y + quat.z * quat.z - ); - float w = quat.x * s; - float x = quat.y * s; - float y = quat.z * s; - float z = quat.w * s; + float w = quat.x; + float x = quat.y; + float y = quat.z; + float z = quat.w; float4 v_quat; // v_R is COLUMN MAJOR diff --git a/gsplat/project_gaussians.py b/gsplat/project_gaussians.py index b41192c5b..e427f1d22 100644 --- a/gsplat/project_gaussians.py +++ b/gsplat/project_gaussians.py @@ -34,7 +34,7 @@ def project_gaussians( means3d (Tensor): xyzs of gaussians. scales (Tensor): scales of the gaussians. glob_scale (float): A global scaling factor applied to the scene. - quats (Tensor): rotations in quaternion [w,x,y,z] format. + quats (Tensor): rotations in normalized quaternion [w,x,y,z] format. viewmat (Tensor): view matrix for rendering. fx (float): focal length x. fy (float): focal length y. @@ -57,6 +57,7 @@ def project_gaussians( - **cov3d** (Tensor): 3D covariances. """ assert block_width > 1 and block_width <= 16, "block_width must be between 2 and 16" + assert (quats.norm(dim=-1) - 1 < 1e-6).all(), "quats must be normalized" return _ProjectGaussians.apply( means3d.contiguous(), scales.contiguous(),