Skip to content

Commit

Permalink
Added orthographic projection (#349)
Browse files Browse the repository at this point in the history
  • Loading branch information
VladislavZavadskyy authored Aug 21, 2024
1 parent 45d196a commit 5ec2670
Show file tree
Hide file tree
Showing 15 changed files with 415 additions and 108 deletions.
2 changes: 1 addition & 1 deletion docs/source/apis/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Below are the basic functions that supports the rasterization.

.. autofunction:: quat_scale_to_covar_preci

.. autofunction:: persp_proj
.. autofunction:: proj

.. autofunction:: fully_fused_projection

Expand Down
4 changes: 2 additions & 2 deletions gsplat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
fully_fused_projection,
isect_offset_encode,
isect_tiles,
persp_proj,
proj,
quat_scale_to_covar_preci,
rasterize_to_indices_in_range,
rasterize_to_pixels,
Expand Down Expand Up @@ -110,7 +110,7 @@ def get_tile_bin_edges(*args, **kwargs):
"spherical_harmonics",
"isect_offset_encode",
"isect_tiles",
"persp_proj",
"proj",
"fully_fused_projection",
"quat_scale_to_covar_preci",
"rasterize_to_pixels",
Expand Down
47 changes: 45 additions & 2 deletions gsplat/cuda/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _persp_proj(
width: int,
height: int,
) -> Tuple[Tensor, Tensor]:
"""PyTorch implementation of prespective projection for 3D Gaussians.
"""PyTorch implementation of perspective projection for 3D Gaussians.
Args:
means: Gaussian means in camera coordinate system. [C, N, 3].
Expand Down Expand Up @@ -106,6 +106,43 @@ def _persp_proj(
return means2d, cov2d # [C, N, 2], [C, N, 2, 2]


def _ortho_proj(
means: Tensor, # [C, N, 3]
covars: Tensor, # [C, N, 3, 3]
Ks: Tensor, # [C, 3, 3]
width: int,
height: int,
) -> Tuple[Tensor, Tensor]:
"""PyTorch implementation of orthographic projection for 3D Gaussians.
Args:
means: Gaussian means in camera coordinate system. [C, N, 3].
covars: Gaussian covariances in camera coordinate system. [C, N, 3, 3].
Ks: Camera intrinsics. [C, 3, 3].
width: Image width.
height: Image height.
Returns:
A tuple:
- **means2d**: Projected means. [C, N, 2].
- **cov2d**: Projected covariances. [C, N, 2, 2].
"""
C, N, _ = means.shape

fx = Ks[..., 0, 0, None] # [C, 1]
fy = Ks[..., 1, 1, None] # [C, 1]

O = torch.zeros((C, 1), device=means.device, dtype=means.dtype)
J = torch.stack([fx, O, O, O, fy, O], dim=-1).reshape(C, 1, 2, 3).repeat(1, N, 1, 1)

cov2d = torch.einsum("...ij,...jk,...kl->...il", J, covars, J.transpose(-1, -2))
means2d = (
means[..., :2] * Ks[:, None, [0, 1], [0, 1]] + Ks[:, None, [0, 1], [2, 2]]
) # [C, N, 2]
return means2d, cov2d # [C, N, 2], [C, N, 2, 2]


def _world_to_cam(
means: Tensor, # [N, 3]
covars: Tensor, # [N, 3, 3]
Expand Down Expand Up @@ -142,6 +179,7 @@ def _fully_fused_projection(
near_plane: float = 0.01,
far_plane: float = 1e10,
calc_compensations: bool = False,
ortho: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
"""PyTorch implementation of `gsplat.cuda._wrapper.fully_fused_projection()`
Expand All @@ -151,7 +189,12 @@ def _fully_fused_projection(
arguments. Not all arguments are supported.
"""
means_c, covars_c = _world_to_cam(means, covars, viewmats)
means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height)

if ortho:
means2d, covars2d = _ortho_proj(means_c, covars_c, Ks, width, height)
else:
means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height)

det_orig = (
covars2d[..., 0, 0] * covars2d[..., 1, 1]
- covars2d[..., 0, 1] * covars2d[..., 1, 0]
Expand Down
66 changes: 59 additions & 7 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Optional, Tuple
import warnings

import torch
from torch import Tensor
Expand Down Expand Up @@ -85,6 +86,38 @@ def persp_proj(
height: int,
) -> Tuple[Tensor, Tensor]:
"""Perspective projection on Gaussians.
DEPRECATED: please use `proj` with `ortho=False` instead.
Args:
means: Gaussian means. [C, N, 3]
covars: Gaussian covariances. [C, N, 3, 3]
Ks: Camera intrinsics. [C, 3, 3]
width: Image width.
height: Image height.
Returns:
A tuple:
- **Projected means**. [C, N, 2]
- **Projected covariances**. [C, N, 2, 2]
"""
warnings.warn(
"persp_proj is deprecated and will be removed in a future release. "
"Use proj with ortho=False instead.",
DeprecationWarning,
)
return proj(means, covars, Ks, width, height, ortho=False)


def proj(
means: Tensor, # [C, N, 3]
covars: Tensor, # [C, N, 3, 3]
Ks: Tensor, # [C, 3, 3]
width: int,
height: int,
ortho: bool,
) -> Tuple[Tensor, Tensor]:
"""Projection of Gaussians (perspective or orthographic).
Args:
means: Gaussian means. [C, N, 3]
Expand All @@ -106,7 +139,7 @@ def persp_proj(
means = means.contiguous()
covars = covars.contiguous()
Ks = Ks.contiguous()
return _PerspProj.apply(means, covars, Ks, width, height)
return _Proj.apply(means, covars, Ks, width, height, ortho)


def world_to_cam(
Expand Down Expand Up @@ -154,12 +187,13 @@ def fully_fused_projection(
packed: bool = False,
sparse_grad: bool = False,
calc_compensations: bool = False,
ortho: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Projects Gaussians to 2D.
This function fuses the process of computing covariances
(:func:`quat_scale_to_covar_preci()`), transforming to camera space (:func:`world_to_cam()`),
and perspective projection (:func:`persp_proj()`).
and projection (:func:`proj()`).
.. note::
Expand Down Expand Up @@ -255,6 +289,7 @@ def fully_fused_projection(
radius_clip,
sparse_grad,
calc_compensations,
ortho,
)
else:
return _FullyFusedProjection.apply(
Expand All @@ -271,6 +306,7 @@ def fully_fused_projection(
far_plane,
radius_clip,
calc_compensations,
ortho,
)


Expand Down Expand Up @@ -619,7 +655,7 @@ def backward(ctx, v_covars: Tensor, v_precis: Tensor):
return v_quats, v_scales, None, None, None


class _PerspProj(torch.autograd.Function):
class _Proj(torch.autograd.Function):
"""Perspective fully_fused_projection on Gaussians."""

@staticmethod
Expand All @@ -630,30 +666,34 @@ def forward(
Ks: Tensor, # [C, 3, 3]
width: int,
height: int,
ortho: bool,
) -> Tuple[Tensor, Tensor]:
means2d, covars2d = _make_lazy_cuda_func("persp_proj_fwd")(
means, covars, Ks, width, height
means2d, covars2d = _make_lazy_cuda_func("proj_fwd")(
means, covars, Ks, width, height, ortho
)
ctx.save_for_backward(means, covars, Ks)
ctx.width = width
ctx.height = height
ctx.ortho = ortho
return means2d, covars2d

@staticmethod
def backward(ctx, v_means2d: Tensor, v_covars2d: Tensor):
means, covars, Ks = ctx.saved_tensors
width = ctx.width
height = ctx.height
v_means, v_covars = _make_lazy_cuda_func("persp_proj_bwd")(
ortho = ctx.ortho
v_means, v_covars = _make_lazy_cuda_func("proj_bwd")(
means,
covars,
Ks,
width,
height,
ortho,
v_means2d.contiguous(),
v_covars2d.contiguous(),
)
return v_means, v_covars, None, None, None
return v_means, v_covars, None, None, None, None


class _WorldToCam(torch.autograd.Function):
Expand Down Expand Up @@ -713,6 +753,7 @@ def forward(
far_plane: float,
radius_clip: float,
calc_compensations: bool,
ortho: bool,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
# "covars" and {"quats", "scales"} are mutually exclusive
radii, means2d, depths, conics, compensations = _make_lazy_cuda_func(
Expand All @@ -731,6 +772,7 @@ def forward(
far_plane,
radius_clip,
calc_compensations,
ortho,
)
if not calc_compensations:
compensations = None
Expand All @@ -740,6 +782,7 @@ def forward(
ctx.width = width
ctx.height = height
ctx.eps2d = eps2d
ctx.ortho = ortho

return radii, means2d, depths, conics, compensations

Expand All @@ -759,6 +802,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
width = ctx.width
height = ctx.height
eps2d = ctx.eps2d
ortho = ctx.ortho
if v_compensations is not None:
v_compensations = v_compensations.contiguous()
v_means, v_covars, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
Expand All @@ -773,6 +817,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
width,
height,
eps2d,
ortho,
radii,
conics,
compensations,
Expand Down Expand Up @@ -806,6 +851,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
None,
None,
None,
None,
)


Expand Down Expand Up @@ -959,6 +1005,7 @@ def forward(
radius_clip: float,
sparse_grad: bool,
calc_compensations: bool,
ortho: bool,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
(
indptr,
Expand All @@ -983,6 +1030,7 @@ def forward(
far_plane,
radius_clip,
calc_compensations,
ortho,
)
if not calc_compensations:
compensations = None
Expand All @@ -1002,6 +1050,7 @@ def forward(
ctx.height = height
ctx.eps2d = eps2d
ctx.sparse_grad = sparse_grad
ctx.ortho = ortho

return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations

Expand Down Expand Up @@ -1032,6 +1081,7 @@ def backward(
height = ctx.height
eps2d = ctx.eps2d
sparse_grad = ctx.sparse_grad
ortho = ctx.ortho

if v_compensations is not None:
v_compensations = v_compensations.contiguous()
Expand All @@ -1047,6 +1097,7 @@ def backward(
width,
height,
eps2d,
ortho,
camera_ids,
gaussian_ids,
conics,
Expand Down Expand Up @@ -1121,6 +1172,7 @@ def backward(
None,
None,
None,
None,
)


Expand Down
16 changes: 11 additions & 5 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,22 @@ std::tuple<torch::Tensor, torch::Tensor> quat_scale_to_covar_preci_bwd_tensor(
const bool triu
);

std::tuple<torch::Tensor, torch::Tensor> persp_proj_fwd_tensor(
std::tuple<torch::Tensor, torch::Tensor> proj_fwd_tensor(
const torch::Tensor &means, // [C, N, 3]
const torch::Tensor &covars, // [C, N, 3, 3]
const torch::Tensor &Ks, // [C, 3, 3]
const uint32_t width,
const uint32_t height
const uint32_t height,
const bool ortho
);

std::tuple<torch::Tensor, torch::Tensor> persp_proj_bwd_tensor(
std::tuple<torch::Tensor, torch::Tensor> proj_bwd_tensor(
const torch::Tensor &means, // [C, N, 3]
const torch::Tensor &covars, // [C, N, 3, 3]
const torch::Tensor &Ks, // [C, 3, 3]
const uint32_t width,
const uint32_t height,
const bool ortho,
const torch::Tensor &v_means2d, // [C, N, 2]
const torch::Tensor &v_covars2d // [C, N, 2, 2]
);
Expand Down Expand Up @@ -101,7 +103,8 @@ fully_fused_projection_fwd_tensor(
const float near_plane,
const float far_plane,
const float radius_clip,
const bool calc_compensations
const bool calc_compensations,
const bool ortho
);

std::tuple<
Expand All @@ -121,6 +124,7 @@ fully_fused_projection_bwd_tensor(
const uint32_t image_width,
const uint32_t image_height,
const float eps2d,
const bool ortho,
// fwd outputs
const torch::Tensor &radii, // [C, N]
const torch::Tensor &conics, // [C, N, 3]
Expand Down Expand Up @@ -261,7 +265,8 @@ fully_fused_projection_packed_fwd_tensor(
const float near_plane,
const float far_plane,
const float radius_clip,
const bool calc_compensations
const bool calc_compensations,
const bool ortho
);

std::tuple<
Expand All @@ -281,6 +286,7 @@ fully_fused_projection_packed_bwd_tensor(
const uint32_t image_width,
const uint32_t image_height,
const float eps2d,
const bool ortho,
// fwd outputs
const torch::Tensor &camera_ids, // [nnz]
const torch::Tensor &gaussian_ids, // [nnz]
Expand Down
Loading

0 comments on commit 5ec2670

Please sign in to comment.