Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added orthographic projection #349

Merged
merged 3 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/apis/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Below are the basic functions that supports the rasterization.

.. autofunction:: quat_scale_to_covar_preci

.. autofunction:: persp_proj
.. autofunction:: proj

.. autofunction:: fully_fused_projection

Expand Down
4 changes: 2 additions & 2 deletions gsplat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
fully_fused_projection,
isect_offset_encode,
isect_tiles,
persp_proj,
proj,
quat_scale_to_covar_preci,
rasterize_to_indices_in_range,
rasterize_to_pixels,
Expand Down Expand Up @@ -110,7 +110,7 @@ def get_tile_bin_edges(*args, **kwargs):
"spherical_harmonics",
"isect_offset_encode",
"isect_tiles",
"persp_proj",
"proj",
"fully_fused_projection",
"quat_scale_to_covar_preci",
"rasterize_to_pixels",
Expand Down
47 changes: 45 additions & 2 deletions gsplat/cuda/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _persp_proj(
width: int,
height: int,
) -> Tuple[Tensor, Tensor]:
"""PyTorch implementation of prespective projection for 3D Gaussians.
"""PyTorch implementation of perspective projection for 3D Gaussians.

Args:
means: Gaussian means in camera coordinate system. [C, N, 3].
Expand Down Expand Up @@ -106,6 +106,43 @@ def _persp_proj(
return means2d, cov2d # [C, N, 2], [C, N, 2, 2]


def _ortho_proj(
means: Tensor, # [C, N, 3]
covars: Tensor, # [C, N, 3, 3]
Ks: Tensor, # [C, 3, 3]
width: int,
height: int,
) -> Tuple[Tensor, Tensor]:
"""PyTorch implementation of orthographic projection for 3D Gaussians.

Args:
means: Gaussian means in camera coordinate system. [C, N, 3].
covars: Gaussian covariances in camera coordinate system. [C, N, 3, 3].
Ks: Camera intrinsics. [C, 3, 3].
width: Image width.
height: Image height.

Returns:
A tuple:

- **means2d**: Projected means. [C, N, 2].
- **cov2d**: Projected covariances. [C, N, 2, 2].
"""
C, N, _ = means.shape

fx = Ks[..., 0, 0, None] # [C, 1]
fy = Ks[..., 1, 1, None] # [C, 1]

O = torch.zeros((C, 1), device=means.device, dtype=means.dtype)
J = torch.stack([fx, O, O, O, fy, O], dim=-1).reshape(C, 1, 2, 3).repeat(1, N, 1, 1)

cov2d = torch.einsum("...ij,...jk,...kl->...il", J, covars, J.transpose(-1, -2))
means2d = (
means[..., :2] * Ks[:, None, [0, 1], [0, 1]] + Ks[:, None, [0, 1], [2, 2]]
) # [C, N, 2]
return means2d, cov2d # [C, N, 2], [C, N, 2, 2]


def _world_to_cam(
means: Tensor, # [N, 3]
covars: Tensor, # [N, 3, 3]
Expand Down Expand Up @@ -142,6 +179,7 @@ def _fully_fused_projection(
near_plane: float = 0.01,
far_plane: float = 1e10,
calc_compensations: bool = False,
ortho: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
"""PyTorch implementation of `gsplat.cuda._wrapper.fully_fused_projection()`

Expand All @@ -151,7 +189,12 @@ def _fully_fused_projection(
arguments. Not all arguments are supported.
"""
means_c, covars_c = _world_to_cam(means, covars, viewmats)
means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height)

if ortho:
means2d, covars2d = _ortho_proj(means_c, covars_c, Ks, width, height)
else:
means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height)

det_orig = (
covars2d[..., 0, 0] * covars2d[..., 1, 1]
- covars2d[..., 0, 1] * covars2d[..., 1, 0]
Expand Down
66 changes: 59 additions & 7 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Optional, Tuple
import warnings

import torch
from torch import Tensor
Expand Down Expand Up @@ -85,6 +86,38 @@ def persp_proj(
height: int,
) -> Tuple[Tensor, Tensor]:
"""Perspective projection on Gaussians.
DEPRECATED: please use `proj` with `ortho=False` instead.

Args:
means: Gaussian means. [C, N, 3]
covars: Gaussian covariances. [C, N, 3, 3]
Ks: Camera intrinsics. [C, 3, 3]
width: Image width.
height: Image height.

Returns:
A tuple:

- **Projected means**. [C, N, 2]
- **Projected covariances**. [C, N, 2, 2]
"""
warnings.warn(
"persp_proj is deprecated and will be removed in a future release. "
"Use proj with ortho=False instead.",
DeprecationWarning,
)
return proj(means, covars, Ks, width, height, ortho=False)


def proj(
means: Tensor, # [C, N, 3]
covars: Tensor, # [C, N, 3, 3]
Ks: Tensor, # [C, 3, 3]
width: int,
height: int,
ortho: bool,
) -> Tuple[Tensor, Tensor]:
"""Projection of Gaussians (perspective or orthographic).

Args:
means: Gaussian means. [C, N, 3]
Expand All @@ -106,7 +139,7 @@ def persp_proj(
means = means.contiguous()
covars = covars.contiguous()
Ks = Ks.contiguous()
return _PerspProj.apply(means, covars, Ks, width, height)
return _Proj.apply(means, covars, Ks, width, height, ortho)


def world_to_cam(
Expand Down Expand Up @@ -154,12 +187,13 @@ def fully_fused_projection(
packed: bool = False,
sparse_grad: bool = False,
calc_compensations: bool = False,
ortho: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
"""Projects Gaussians to 2D.

This function fuse the process of computing covariances
(:func:`quat_scale_to_covar_preci()`), transforming to camera space (:func:`world_to_cam()`),
and perspective projection (:func:`persp_proj()`).
and projection (:func:`proj()`).

.. note::

Expand Down Expand Up @@ -255,6 +289,7 @@ def fully_fused_projection(
radius_clip,
sparse_grad,
calc_compensations,
ortho,
)
else:
return _FullyFusedProjection.apply(
Expand All @@ -271,6 +306,7 @@ def fully_fused_projection(
far_plane,
radius_clip,
calc_compensations,
ortho,
)


Expand Down Expand Up @@ -619,7 +655,7 @@ def backward(ctx, v_covars: Tensor, v_precis: Tensor):
return v_quats, v_scales, None, None, None


class _PerspProj(torch.autograd.Function):
class _Proj(torch.autograd.Function):
"""Perspective fully_fused_projection on Gaussians."""

@staticmethod
Expand All @@ -630,30 +666,34 @@ def forward(
Ks: Tensor, # [C, 3, 3]
width: int,
height: int,
ortho: bool,
) -> Tuple[Tensor, Tensor]:
means2d, covars2d = _make_lazy_cuda_func("persp_proj_fwd")(
means, covars, Ks, width, height
means2d, covars2d = _make_lazy_cuda_func("proj_fwd")(
means, covars, Ks, width, height, ortho
)
ctx.save_for_backward(means, covars, Ks)
ctx.width = width
ctx.height = height
ctx.ortho = ortho
return means2d, covars2d

@staticmethod
def backward(ctx, v_means2d: Tensor, v_covars2d: Tensor):
means, covars, Ks = ctx.saved_tensors
width = ctx.width
height = ctx.height
v_means, v_covars = _make_lazy_cuda_func("persp_proj_bwd")(
ortho = ctx.ortho
v_means, v_covars = _make_lazy_cuda_func("proj_bwd")(
means,
covars,
Ks,
width,
height,
ortho,
v_means2d.contiguous(),
v_covars2d.contiguous(),
)
return v_means, v_covars, None, None, None
return v_means, v_covars, None, None, None, None


class _WorldToCam(torch.autograd.Function):
Expand Down Expand Up @@ -713,6 +753,7 @@ def forward(
far_plane: float,
radius_clip: float,
calc_compensations: bool,
ortho: bool,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
# "covars" and {"quats", "scales"} are mutually exclusive
radii, means2d, depths, conics, compensations = _make_lazy_cuda_func(
Expand All @@ -731,6 +772,7 @@ def forward(
far_plane,
radius_clip,
calc_compensations,
ortho,
)
if not calc_compensations:
compensations = None
Expand All @@ -740,6 +782,7 @@ def forward(
ctx.width = width
ctx.height = height
ctx.eps2d = eps2d
ctx.ortho = ortho

return radii, means2d, depths, conics, compensations

Expand All @@ -759,6 +802,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
width = ctx.width
height = ctx.height
eps2d = ctx.eps2d
ortho = ctx.ortho
if v_compensations is not None:
v_compensations = v_compensations.contiguous()
v_means, v_covars, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
Expand All @@ -773,6 +817,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
width,
height,
eps2d,
ortho,
radii,
conics,
compensations,
Expand Down Expand Up @@ -806,6 +851,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
None,
None,
None,
None,
)


Expand Down Expand Up @@ -959,6 +1005,7 @@ def forward(
radius_clip: float,
sparse_grad: bool,
calc_compensations: bool,
ortho: bool,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
(
indptr,
Expand All @@ -983,6 +1030,7 @@ def forward(
far_plane,
radius_clip,
calc_compensations,
ortho,
)
if not calc_compensations:
compensations = None
Expand All @@ -1002,6 +1050,7 @@ def forward(
ctx.height = height
ctx.eps2d = eps2d
ctx.sparse_grad = sparse_grad
ctx.ortho = ortho

return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations

Expand Down Expand Up @@ -1032,6 +1081,7 @@ def backward(
height = ctx.height
eps2d = ctx.eps2d
sparse_grad = ctx.sparse_grad
ortho = ctx.ortho

if v_compensations is not None:
v_compensations = v_compensations.contiguous()
Expand All @@ -1047,6 +1097,7 @@ def backward(
width,
height,
eps2d,
ortho,
camera_ids,
gaussian_ids,
conics,
Expand Down Expand Up @@ -1121,6 +1172,7 @@ def backward(
None,
None,
None,
None,
)


Expand Down
16 changes: 11 additions & 5 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,22 @@ std::tuple<torch::Tensor, torch::Tensor> quat_scale_to_covar_preci_bwd_tensor(
const bool triu
);

std::tuple<torch::Tensor, torch::Tensor> persp_proj_fwd_tensor(
std::tuple<torch::Tensor, torch::Tensor> proj_fwd_tensor(
const torch::Tensor &means, // [C, N, 3]
const torch::Tensor &covars, // [C, N, 3, 3]
const torch::Tensor &Ks, // [C, 3, 3]
const uint32_t width,
const uint32_t height
const uint32_t height,
const bool ortho
);

std::tuple<torch::Tensor, torch::Tensor> persp_proj_bwd_tensor(
std::tuple<torch::Tensor, torch::Tensor> proj_bwd_tensor(
const torch::Tensor &means, // [C, N, 3]
const torch::Tensor &covars, // [C, N, 3, 3]
const torch::Tensor &Ks, // [C, 3, 3]
const uint32_t width,
const uint32_t height,
const bool ortho,
const torch::Tensor &v_means2d, // [C, N, 2]
const torch::Tensor &v_covars2d // [C, N, 2, 2]
);
Expand Down Expand Up @@ -101,7 +103,8 @@ fully_fused_projection_fwd_tensor(
const float near_plane,
const float far_plane,
const float radius_clip,
const bool calc_compensations
const bool calc_compensations,
const bool ortho
);

std::tuple<
Expand All @@ -121,6 +124,7 @@ fully_fused_projection_bwd_tensor(
const uint32_t image_width,
const uint32_t image_height,
const float eps2d,
const bool ortho,
// fwd outputs
const torch::Tensor &radii, // [C, N]
const torch::Tensor &conics, // [C, N, 3]
Expand Down Expand Up @@ -261,7 +265,8 @@ fully_fused_projection_packed_fwd_tensor(
const float near_plane,
const float far_plane,
const float radius_clip,
const bool calc_compensations
const bool calc_compensations,
const bool ortho
);

std::tuple<
Expand All @@ -281,6 +286,7 @@ fully_fused_projection_packed_bwd_tensor(
const uint32_t image_width,
const uint32_t image_height,
const float eps2d,
const bool ortho,
// fwd outputs
const torch::Tensor &camera_ids, // [nnz]
const torch::Tensor &gaussian_ids, // [nnz]
Expand Down
Loading
Loading