Skip to content

Commit

Permalink
Implement approximate gradient for viewmat (#127)
Browse files Browse the repository at this point in the history
Allows camera pose optimization in Nerfstudio
  • Loading branch information
oseiskar authored Mar 26, 2024
1 parent 797a363 commit 1516b9d
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion gsplat/project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Optional, Tuple

import torch
from jaxtyping import Float
from torch import Tensor
from torch.autograd import Function
Expand Down Expand Up @@ -195,6 +196,41 @@ def backward(
v_compensation,
)

if viewmat.requires_grad:
v_viewmat = torch.zeros_like(viewmat)
R = viewmat[..., :3, :3]

# Denote ProjectGaussians for a single Gaussian (mean3d, q, s)
# viemwat = [R, t] as:
#
# f(mean3d, q, s, R, t, intrinsics)
# = g(R @ mean3d + t,
# R @ cov3d_world(q, s) @ R^T ))
#
# Then, the Jacobian w.r.t., t is:
#
# d f / d t = df / d mean3d @ R^T
#
# and, in the context of fine tuning camera poses, it is reasonable
# to assume that
#
# d f / d R_ij =~ \sum_l d f / d t_l * d (R @ mean3d)_l / d R_ij
# = d f / d_t_i * mean3d[j]
#
# Gradients for R and t can then be obtained by summing over
# all the Gaussians.
v_mean3d_cam = torch.matmul(v_mean3d, R.transpose(-1, -2))

# gradient w.r.t. view matrix translation
v_viewmat[..., :3, 3] = v_mean3d_cam.sum(-2)

# gradent w.r.t. view matrix rotation
for j in range(3):
for l in range(3):
v_viewmat[..., j, l] = torch.dot(v_mean3d_cam[..., j], means3d[..., l])
else:
v_viewmat = None

# Return a gradient for each input.
return (
# means3d: Float[Tensor, "*batch 3"],
Expand All @@ -206,7 +242,7 @@ def backward(
# quats: Float[Tensor, "*batch 4"],
v_quat,
# viewmat: Float[Tensor, "4 4"],
None,
v_viewmat,
# projmat: Float[Tensor, "4 4"],
None,
# fx: float,
Expand Down

0 comments on commit 1516b9d

Please sign in to comment.