Skip to content

Commit

Permalink
Check for normalized quats. (#155)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mxbonn authored Apr 2, 2024
1 parent 4609d3b commit c54fe8b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 16 deletions.
2 changes: 1 addition & 1 deletion examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def train(
self.means,
self.scales,
1,
self.quats,
self.quats / self.quats.norm(dim=-1, keepdim=True),
self.viewmat,
self.focal,
self.focal,
Expand Down
22 changes: 8 additions & 14 deletions gsplat/cuda/csrc/helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,10 @@ inline __device__ float3 project_pix_vjp(

inline __device__ glm::mat3 quat_to_rotmat(const float4 quat) {
// quat to rotation matrix
float s = rsqrtf(
quat.w * quat.w + quat.x * quat.x + quat.y * quat.y + quat.z * quat.z
);
float w = quat.x * s;
float x = quat.y * s;
float y = quat.z * s;
float z = quat.w * s;
float w = quat.x;
float x = quat.y;
float y = quat.z;
float z = quat.w;

// glm matrices are column-major
return glm::mat3(
Expand All @@ -172,13 +169,10 @@ inline __device__ glm::mat3 quat_to_rotmat(const float4 quat) {

inline __device__ float4
quat_to_rotmat_vjp(const float4 quat, const glm::mat3 v_R) {
float s = rsqrtf(
quat.w * quat.w + quat.x * quat.x + quat.y * quat.y + quat.z * quat.z
);
float w = quat.x * s;
float x = quat.y * s;
float y = quat.z * s;
float z = quat.w * s;
float w = quat.x;
float x = quat.y;
float y = quat.z;
float z = quat.w;

float4 v_quat;
// v_R is COLUMN MAJOR
Expand Down
3 changes: 2 additions & 1 deletion gsplat/project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def project_gaussians(
means3d (Tensor): xyzs of gaussians.
scales (Tensor): scales of the gaussians.
glob_scale (float): A global scaling factor applied to the scene.
quats (Tensor): rotations in quaternion [w,x,y,z] format.
quats (Tensor): rotations in normalized quaternion [w,x,y,z] format.
viewmat (Tensor): view matrix for rendering.
fx (float): focal length x.
fy (float): focal length y.
Expand All @@ -57,6 +57,7 @@ def project_gaussians(
- **cov3d** (Tensor): 3D covariances.
"""
assert block_width > 1 and block_width <= 16, "block_width must be between 2 and 16"
assert (quats.norm(dim=-1) - 1 < 1e-6).all(), "quats must be normalized"
return _ProjectGaussians.apply(
means3d.contiguous(),
scales.contiguous(),
Expand Down

0 comments on commit c54fe8b

Please sign in to comment.