Skip to content

Commit

Permalink
Added orthographic projection
Browse files Browse the repository at this point in the history
  • Loading branch information
VladislavZavadskyy committed Aug 21, 2024
1 parent 45d196a commit 9fbf1e2
Show file tree
Hide file tree
Showing 15 changed files with 379 additions and 106 deletions.
2 changes: 1 addition & 1 deletion docs/source/apis/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Below are the basic functions that supports the rasterization.

.. autofunction:: quat_scale_to_covar_preci

.. autofunction:: persp_proj
.. autofunction:: proj

.. autofunction:: fully_fused_projection

Expand Down
4 changes: 2 additions & 2 deletions gsplat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
fully_fused_projection,
isect_offset_encode,
isect_tiles,
persp_proj,
proj,
quat_scale_to_covar_preci,
rasterize_to_indices_in_range,
rasterize_to_pixels,
Expand Down Expand Up @@ -110,7 +110,7 @@ def get_tile_bin_edges(*args, **kwargs):
"spherical_harmonics",
"isect_offset_encode",
"isect_tiles",
"persp_proj",
"proj",
"fully_fused_projection",
"quat_scale_to_covar_preci",
"rasterize_to_pixels",
Expand Down
45 changes: 44 additions & 1 deletion gsplat/cuda/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,43 @@ def _persp_proj(
return means2d, cov2d # [C, N, 2], [C, N, 2, 2]


def _ortho_proj(
means: Tensor, # [C, N, 3]
covars: Tensor, # [C, N, 3, 3]
Ks: Tensor, # [C, 3, 3]
width: int,
height: int,
) -> Tuple[Tensor, Tensor]:
"""PyTorch implementation of prespective projection for 3D Gaussians.
Args:
means: Gaussian means in camera coordinate system. [C, N, 3].
covars: Gaussian covariances in camera coordinate system. [C, N, 3, 3].
Ks: Camera intrinsics. [C, 3, 3].
width: Image width.
height: Image height.
Returns:
A tuple:
- **means2d**: Projected means. [C, N, 2].
- **cov2d**: Projected covariances. [C, N, 2, 2].
"""
C, N, _ = means.shape

fx = Ks[..., 0, 0, None] # [C, 1]
fy = Ks[..., 1, 1, None] # [C, 1]

O = torch.zeros((C, N), device=means.device, dtype=means.dtype)
J = torch.stack(
[fx, O, O, O, fy, O], dim=-1
).reshape(C, N, 2, 3)

cov2d = torch.einsum("...ij,...jk,...kl->...il", J, covars, J.transpose(-1, -2))
means2d = torch.einsum("cij,cnj->cni", Ks[:, :2, :3], means) # [C, N, 2]
return means2d, cov2d # [C, N, 2], [C, N, 2, 2]


def _world_to_cam(
means: Tensor, # [N, 3]
covars: Tensor, # [N, 3, 3]
Expand Down Expand Up @@ -142,6 +179,7 @@ def _fully_fused_projection(
near_plane: float = 0.01,
far_plane: float = 1e10,
calc_compensations: bool = False,
ortho: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
"""PyTorch implementation of `gsplat.cuda._wrapper.fully_fused_projection()`
Expand All @@ -151,7 +189,12 @@ def _fully_fused_projection(
arguments. Not all arguments are supported.
"""
means_c, covars_c = _world_to_cam(means, covars, viewmats)
means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height)

if ortho:
means2d, covars2d = _ortho_proj(means_c, covars_c, Ks, width, height)
else:
means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height)

det_orig = (
covars2d[..., 0, 0] * covars2d[..., 1, 1]
- covars2d[..., 0, 1] * covars2d[..., 1, 0]
Expand Down
34 changes: 26 additions & 8 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,13 @@ def quat_scale_to_covar_preci(
return covars if compute_covar else None, precis if compute_preci else None


def persp_proj(
def proj(
means: Tensor, # [C, N, 3]
covars: Tensor, # [C, N, 3, 3]
Ks: Tensor, # [C, 3, 3]
width: int,
height: int,
ortho: bool
) -> Tuple[Tensor, Tensor]:
"""Perspective projection on Gaussians.
Expand All @@ -106,7 +107,7 @@ def persp_proj(
means = means.contiguous()
covars = covars.contiguous()
Ks = Ks.contiguous()
return _PerspProj.apply(means, covars, Ks, width, height)
return _Proj.apply(means, covars, Ks, width, height, ortho)


def world_to_cam(
Expand Down Expand Up @@ -154,12 +155,13 @@ def fully_fused_projection(
packed: bool = False,
sparse_grad: bool = False,
calc_compensations: bool = False,
ortho: bool = False
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Projects Gaussians to 2D.
This function fuse the process of computing covariances
(:func:`quat_scale_to_covar_preci()`), transforming to camera space (:func:`world_to_cam()`),
and perspective projection (:func:`persp_proj()`).
and projection (:func:`proj()`).
.. note::
Expand Down Expand Up @@ -255,6 +257,7 @@ def fully_fused_projection(
radius_clip,
sparse_grad,
calc_compensations,
ortho
)
else:
return _FullyFusedProjection.apply(
Expand All @@ -271,6 +274,7 @@ def fully_fused_projection(
far_plane,
radius_clip,
calc_compensations,
ortho
)


Expand Down Expand Up @@ -619,7 +623,7 @@ def backward(ctx, v_covars: Tensor, v_precis: Tensor):
return v_quats, v_scales, None, None, None


class _PerspProj(torch.autograd.Function):
class _Proj(torch.autograd.Function):
"""Perspective fully_fused_projection on Gaussians."""

@staticmethod
Expand All @@ -630,26 +634,30 @@ def forward(
Ks: Tensor, # [C, 3, 3]
width: int,
height: int,
ortho: bool
) -> Tuple[Tensor, Tensor]:
means2d, covars2d = _make_lazy_cuda_func("persp_proj_fwd")(
means, covars, Ks, width, height
means2d, covars2d = _make_lazy_cuda_func("proj_fwd")(
means, covars, Ks, width, height, ortho
)
ctx.save_for_backward(means, covars, Ks)
ctx.width = width
ctx.height = height
ctx.ortho = ortho
return means2d, covars2d

@staticmethod
def backward(ctx, v_means2d: Tensor, v_covars2d: Tensor):
means, covars, Ks = ctx.saved_tensors
width = ctx.width
height = ctx.height
v_means, v_covars = _make_lazy_cuda_func("persp_proj_bwd")(
ortho = ctx.ortho
v_means, v_covars = _make_lazy_cuda_func("proj_bwd")(
means,
covars,
Ks,
width,
height,
ortho,
v_means2d.contiguous(),
v_covars2d.contiguous(),
)
Expand Down Expand Up @@ -713,6 +721,7 @@ def forward(
far_plane: float,
radius_clip: float,
calc_compensations: bool,
ortho: bool
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
# "covars" and {"quats", "scales"} are mutually exclusive
radii, means2d, depths, conics, compensations = _make_lazy_cuda_func(
Expand All @@ -731,6 +740,7 @@ def forward(
far_plane,
radius_clip,
calc_compensations,
ortho
)
if not calc_compensations:
compensations = None
Expand All @@ -740,6 +750,7 @@ def forward(
ctx.width = width
ctx.height = height
ctx.eps2d = eps2d
ctx.ortho = ortho

return radii, means2d, depths, conics, compensations

Expand All @@ -759,6 +770,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
width = ctx.width
height = ctx.height
eps2d = ctx.eps2d
ortho = ctx.ortho
if v_compensations is not None:
v_compensations = v_compensations.contiguous()
v_means, v_covars, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
Expand All @@ -773,6 +785,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
width,
height,
eps2d,
ortho,
radii,
conics,
compensations,
Expand Down Expand Up @@ -959,6 +972,7 @@ def forward(
radius_clip: float,
sparse_grad: bool,
calc_compensations: bool,
ortho: bool
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
(
indptr,
Expand All @@ -983,6 +997,7 @@ def forward(
far_plane,
radius_clip,
calc_compensations,
ortho
)
if not calc_compensations:
compensations = None
Expand All @@ -1002,7 +1017,8 @@ def forward(
ctx.height = height
ctx.eps2d = eps2d
ctx.sparse_grad = sparse_grad

ctx.ortho = ortho

return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations

@staticmethod
Expand Down Expand Up @@ -1032,6 +1048,7 @@ def backward(
height = ctx.height
eps2d = ctx.eps2d
sparse_grad = ctx.sparse_grad
ortho = ctx.ortho

if v_compensations is not None:
v_compensations = v_compensations.contiguous()
Expand All @@ -1047,6 +1064,7 @@ def backward(
width,
height,
eps2d,
ortho,
camera_ids,
gaussian_ids,
conics,
Expand Down
16 changes: 11 additions & 5 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,22 @@ std::tuple<torch::Tensor, torch::Tensor> quat_scale_to_covar_preci_bwd_tensor(
const bool triu
);

std::tuple<torch::Tensor, torch::Tensor> persp_proj_fwd_tensor(
std::tuple<torch::Tensor, torch::Tensor> proj_fwd_tensor(
const torch::Tensor &means, // [C, N, 3]
const torch::Tensor &covars, // [C, N, 3, 3]
const torch::Tensor &Ks, // [C, 3, 3]
const uint32_t width,
const uint32_t height
const uint32_t height,
const bool ortho
);

std::tuple<torch::Tensor, torch::Tensor> persp_proj_bwd_tensor(
std::tuple<torch::Tensor, torch::Tensor> proj_bwd_tensor(
const torch::Tensor &means, // [C, N, 3]
const torch::Tensor &covars, // [C, N, 3, 3]
const torch::Tensor &Ks, // [C, 3, 3]
const uint32_t width,
const uint32_t height,
const bool ortho,
const torch::Tensor &v_means2d, // [C, N, 2]
const torch::Tensor &v_covars2d // [C, N, 2, 2]
);
Expand Down Expand Up @@ -101,7 +103,8 @@ fully_fused_projection_fwd_tensor(
const float near_plane,
const float far_plane,
const float radius_clip,
const bool calc_compensations
const bool calc_compensations,
const bool ortho
);

std::tuple<
Expand All @@ -121,6 +124,7 @@ fully_fused_projection_bwd_tensor(
const uint32_t image_width,
const uint32_t image_height,
const float eps2d,
const bool ortho,
// fwd outputs
const torch::Tensor &radii, // [C, N]
const torch::Tensor &conics, // [C, N, 3]
Expand Down Expand Up @@ -261,7 +265,8 @@ fully_fused_projection_packed_fwd_tensor(
const float near_plane,
const float far_plane,
const float radius_clip,
const bool calc_compensations
const bool calc_compensations,
const bool ortho
);

std::tuple<
Expand All @@ -281,6 +286,7 @@ fully_fused_projection_packed_bwd_tensor(
const uint32_t image_width,
const uint32_t image_height,
const float eps2d,
const bool ortho,
// fwd outputs
const torch::Tensor &camera_ids, // [nnz]
const torch::Tensor &gaussian_ids, // [nnz]
Expand Down
4 changes: 2 additions & 2 deletions gsplat/cuda/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&gsplat::quat_scale_to_covar_preci_bwd_tensor
);

m.def("persp_proj_fwd", &gsplat::persp_proj_fwd_tensor);
m.def("persp_proj_bwd", &gsplat::persp_proj_bwd_tensor);
m.def("proj_fwd", &gsplat::proj_fwd_tensor);
m.def("proj_bwd", &gsplat::proj_bwd_tensor);

m.def("world_to_cam_fwd", &gsplat::world_to_cam_fwd_tensor);
m.def("world_to_cam_bwd", &gsplat::world_to_cam_bwd_tensor);
Expand Down
Loading

0 comments on commit 9fbf1e2

Please sign in to comment.