Skip to content

Commit

Permalink
Compute density compensation for screen space blurring (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored Feb 8, 2024
1 parent 9fffa1a commit bb01f14
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 25 deletions.
38 changes: 27 additions & 11 deletions gsplat/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch.nn.functional as F
from jaxtyping import Float
from torch import Tensor
from typing import Tuple


def compute_sh_color(
Expand Down Expand Up @@ -116,15 +117,15 @@ def quat_to_rotmat(quat: Tensor) -> Tensor:
w, x, y, z = torch.unbind(F.normalize(quat, dim=-1), dim=-1)
mat = torch.stack(
[
1 - 2 * (y ** 2 + z ** 2),
1 - 2 * (y**2 + z**2),
2 * (x * y - w * z),
2 * (x * z + w * y),
2 * (x * y + w * z),
1 - 2 * (x ** 2 + z ** 2),
1 - 2 * (x**2 + z**2),
2 * (y * z - w * x),
2 * (x * z - w * y),
2 * (y * z + w * x),
1 - 2 * (x ** 2 + y ** 2),
1 - 2 * (x**2 + y**2),
],
dim=-1,
)
Expand All @@ -149,7 +150,7 @@ def project_cov3d_ewa(
fy: float,
tan_fovx: float,
tan_fovy: float,
) -> Tensor:
) -> Tuple[Tensor, Tensor]:
assert mean3d.shape[-1] == 3, mean3d.shape
assert cov3d.shape[-2:] == (3, 3), cov3d.shape
assert viewmat.shape[-2:] == (4, 4), viewmat.shape
Expand All @@ -158,7 +159,7 @@ def project_cov3d_ewa(
t = torch.einsum("...ij,...j->...i", W, mean3d) + p # (..., 3)

rz = 1.0 / t[..., 2] # (...,)
rz2 = rz ** 2 # (...,)
rz2 = rz**2 # (...,)

lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device)
lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device)
Expand All @@ -174,9 +175,12 @@ def project_cov3d_ewa(
T = torch.matmul(J, W) # (..., 2, 3)
cov2d = torch.einsum("...ij,...jk,...kl->...il", T, cov3d, T.transpose(-1, -2))
# add a little blur along axes and (TODO save upper triangular elements)
det_orig = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1]
cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.3
cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3
return cov2d[..., :2, :2]
det_blur = cov2d[..., 0, 0] * cov2d[..., 1, 1] - cov2d[..., 0, 1] * cov2d[..., 0, 1]
compensation = torch.sqrt(torch.clamp(det_orig / det_blur, min=0))
return cov2d[..., :2, :2], compensation.detach()


def compute_cov2d_bounds(cov2d_mat: Tensor):
Expand All @@ -198,8 +202,8 @@ def compute_cov2d_bounds(cov2d_mat: Tensor):
dim=-1,
) # (..., 3)
b = (cov2d[..., 0, 0] + cov2d[..., 1, 1]) / 2 # (...,)
v1 = b + torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,)
v2 = b - torch.sqrt(torch.clamp(b ** 2 - det, min=0.1)) # (...,)
v1 = b + torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,)
v2 = b - torch.sqrt(torch.clamp(b**2 - det, min=0.1)) # (...,)
radius = torch.ceil(3.0 * torch.sqrt(torch.max(v1, v2))) # (...,)
radius_all = torch.zeros(*cov2d_mat.shape[:-2], device=cov2d_mat.device)
conic_all = torch.zeros(*cov2d_mat.shape[:-2], 3, device=cov2d_mat.device)
Expand Down Expand Up @@ -272,7 +276,9 @@ def project_gaussians_forward(
tan_fovy = 0.5 * img_size[1] / fy
p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh)
cov3d = scale_rot_to_cov3d(scales, glob_scale, quats)
cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy)
cov2d, compensation = project_cov3d_ewa(
means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy
)
conic, radius, det_valid = compute_cov2d_bounds(cov2d)
xys = project_pix(fullmat, means3d, img_size, (cx, cy))
tile_min, tile_max = get_tile_bbox(xys, radius, tile_bounds)
Expand All @@ -290,14 +296,25 @@ def project_gaussians_forward(
xys = torch.where(~mask[..., None], 0, xys)
cov3d = torch.where(~mask[..., None, None], 0, cov3d)
cov2d = torch.where(~mask[..., None, None], 0, cov2d)
compensation = torch.where(~mask, 0, compensation)
num_tiles_hit = torch.where(~mask, 0, num_tiles_hit)
depths = torch.where(~mask, 0, depths)

i, j = torch.triu_indices(3, 3)
cov3d_triu = cov3d[..., i, j]
i, j = torch.triu_indices(2, 2)
cov2d_triu = cov2d[..., i, j]
return cov3d_triu, cov2d_triu, xys, depths, radii, conic, num_tiles_hit, mask
return (
cov3d_triu,
cov2d_triu,
xys,
depths,
radii,
conic,
compensation,
num_tiles_hit,
mask,
)


def map_gaussian_to_intersects(
Expand Down Expand Up @@ -339,7 +356,6 @@ def get_tile_bin_edges(num_intersects, isect_ids_sorted, tile_bounds):
)

for idx in range(num_intersects):

cur_tile_idx = isect_ids_sorted[idx] >> 32

if idx == 0:
Expand Down
6 changes: 5 additions & 1 deletion gsplat/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor>
project_gaussians_forward_tensor(
const int num_points,
Expand Down Expand Up @@ -162,6 +163,8 @@ project_gaussians_forward_tensor(
torch::zeros({num_points}, means3d.options().dtype(torch::kInt32));
torch::Tensor conics_d =
torch::zeros({num_points, 3}, means3d.options().dtype(torch::kFloat32));
torch::Tensor compensation_d =
torch::zeros({num_points}, means3d.options().dtype(torch::kFloat32));
torch::Tensor num_tiles_hit_d =
torch::zeros({num_points}, means3d.options().dtype(torch::kInt32));

Expand All @@ -185,11 +188,12 @@ project_gaussians_forward_tensor(
depths_d.contiguous().data_ptr<float>(),
radii_d.contiguous().data_ptr<int>(),
(float3 *)conics_d.contiguous().data_ptr<float>(),
compensation_d.contiguous().data_ptr<float>(),
num_tiles_hit_d.contiguous().data_ptr<int32_t>()
);

return std::make_tuple(
cov3d_d, xys_d, depths_d, radii_d, conics_d, num_tiles_hit_d
cov3d_d, xys_d, depths_d, radii_d, conics_d, compensation_d, num_tiles_hit_d
);
}

Expand Down
1 change: 1 addition & 0 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor>
project_gaussians_forward_tensor(
const int num_points,
Expand Down
24 changes: 19 additions & 5 deletions gsplat/cuda/csrc/forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ __global__ void project_gaussians_forward_kernel(
float* __restrict__ depths,
int* __restrict__ radii,
float3* __restrict__ conics,
float* __restrict__ compensation,
int32_t* __restrict__ num_tiles_hit
) {
unsigned idx = cg::this_grid().thread_rank(); // idx of thread within grid
Expand Down Expand Up @@ -61,8 +62,11 @@ __global__ void project_gaussians_forward_kernel(
float cy = intrins.w;
float tan_fovx = 0.5 * img_size.x / fx;
float tan_fovy = 0.5 * img_size.y / fy;
float3 cov2d = project_cov3d_ewa(
p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy
float3 cov2d;
float comp;
project_cov3d_ewa(
p_world, cur_cov3d, viewmat, fx, fy, tan_fovx, tan_fovy,
cov2d, comp
);
// printf("cov2d %d, %.2f %.2f %.2f\n", idx, cov2d.x, cov2d.y, cov2d.z);

Expand All @@ -88,6 +92,7 @@ __global__ void project_gaussians_forward_kernel(
depths[idx] = p_view.z;
radii[idx] = (int)radius;
xys[idx] = center;
compensation[idx] = comp;
// printf(
// "point %d x %.2f y %.2f z %.2f, radius %d, # tiles %d, tile_min %d
// %d, tile_max %d %d\n", idx, center.x, center.y, depths[idx],
Expand Down Expand Up @@ -372,14 +377,16 @@ __global__ void rasterize_forward(
}

// device helper to approximate projected 2d cov from 3d mean and cov
__device__ float3 project_cov3d_ewa(
__device__ void project_cov3d_ewa(
const float3& __restrict__ mean3d,
const float* __restrict__ cov3d,
const float* __restrict__ viewmat,
const float fx,
const float fy,
const float tan_fovx,
const float tan_fovy
const float tan_fovy,
float3 &cov2d,
float &compensation
) {
// clip the
// we expect row major matrices as input, glm uses column major
Expand Down Expand Up @@ -437,7 +444,14 @@ __device__ float3 project_cov3d_ewa(
glm::mat3 cov = T * V * glm::transpose(T);

// add a little blur along axes and save upper triangular elements
return make_float3(float(cov[0][0]) + 0.3f, float(cov[0][1]), float(cov[1][1]) + 0.3f);
// and compute the density compensation factor due to the blurs
float c00 = cov[0][0], c11 = cov[1][1], c01 = cov[0][1];
float det_orig = c00 * c11 - c01 * c01;
cov2d.x = c00 + 0.3f;
cov2d.y = c01;
cov2d.z = c11 + 0.3f;
float det_blur = cov2d.x * cov2d.z - cov2d.y * cov2d.y;
compensation = std::sqrt(std::max(0.f, det_orig / det_blur));
}

// device helper to get 3D covariance from scale and quat parameters
Expand Down
7 changes: 5 additions & 2 deletions gsplat/cuda/csrc/forward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ __global__ void project_gaussians_forward_kernel(
float* __restrict__ depths,
int* __restrict__ radii,
float3* __restrict__ conics,
float* __restrict__ compensation,
int32_t* __restrict__ num_tiles_hit
);

Expand Down Expand Up @@ -57,14 +58,16 @@ __global__ void nd_rasterize_forward(
);

// device helper to approximate projected 2d cov from 3d mean and cov
__device__ float3 project_cov3d_ewa(
__device__ void project_cov3d_ewa(
const float3 &mean3d,
const float *cov3d,
const float *viewmat,
const float fx,
const float fy,
const float tan_fovx,
const float tan_fovy
const float tan_fovy,
float3 &cov2d,
float &comp
);

// device helper to get 3D covariance from scale and quat parameters
Expand Down
19 changes: 15 additions & 4 deletions gsplat/project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def project_gaussians(
img_width: int,
tile_bounds: Tuple[int, int, int],
clip_thresh: float = 0.01,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, int, Tensor]:
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
"""This function projects 3D gaussians to 2D using the EWA splatting method for gaussian splatting.
Note:
Expand All @@ -47,12 +47,13 @@ def project_gaussians(
clip_thresh (float): minimum z depth threshold.
Returns:
A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}:
A tuple of {Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor}:
- **xys** (Tensor): x,y locations of 2D gaussian projections.
- **depths** (Tensor): z depth of gaussians.
- **radii** (Tensor): radii of 2D gaussian projections.
- **conics** (Tensor): conic parameters for 2D gaussian.
- **compensation** (Tensor): the density compensation for blurring 2D kernel
- **num_tiles_hit** (Tensor): number of tiles hit per gaussian.
- **cov3d** (Tensor): 3D covariances.
"""
Expand Down Expand Up @@ -105,6 +106,7 @@ def forward(
depths,
radii,
conics,
compensation,
num_tiles_hit,
) = _C.project_gaussians_forward(
num_points,
Expand Down Expand Up @@ -146,10 +148,19 @@ def forward(
conics,
)

return (xys, depths, radii, conics, num_tiles_hit, cov3d)
return (xys, depths, radii, conics, compensation, num_tiles_hit, cov3d)

@staticmethod
def backward(ctx, v_xys, v_depths, v_radii, v_conics, v_num_tiles_hit, v_cov3d):
def backward(
ctx,
v_xys,
v_depths,
v_radii,
v_conics,
v_compensation,
v_num_tiles_hit,
v_cov3d,
):
(
means3d,
scales,
Expand Down
15 changes: 13 additions & 2 deletions tests/test_project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,15 @@ def test_project_gaussians_forward():
BLOCK_X, BLOCK_Y = 16, 16
tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1

(cov3d, xys, depths, radii, conics, num_tiles_hit,) = _C.project_gaussians_forward(
(
cov3d,
xys,
depths,
radii,
conics,
compensation,
num_tiles_hit,
) = _C.project_gaussians_forward(
num_points,
means3d,
scales,
Expand All @@ -93,6 +101,7 @@ def test_project_gaussians_forward():
_depths,
_radii,
_conics,
_compensation,
_num_tiles_hit,
_masks,
) = _torch_impl.project_gaussians_forward(
Expand All @@ -114,6 +123,7 @@ def test_project_gaussians_forward():
check_close(depths, _depths)
check_close(radii, _radii)
check_close(conics, _conics)
check_close(compensation, _compensation)
check_close(num_tiles_hit, _num_tiles_hit)
print("passed project_gaussians_forward test")

Expand Down Expand Up @@ -156,6 +166,7 @@ def test_project_gaussians_backward():
radii,
conics,
_,
_,
masks,
) = _torch_impl.project_gaussians_forward(
means3d,
Expand Down Expand Up @@ -219,7 +230,7 @@ def project_cov3d_ewa_partial(mean3d, cov3d):
i, j = torch.triu_indices(3, 3)
cov3d_mat[..., i, j] = cov3d
cov3d_mat[..., [1, 2, 2], [0, 0, 1]] = cov3d[..., [1, 2, 4]]
cov2d = _torch_impl.project_cov3d_ewa(
cov2d, _ = _torch_impl.project_cov3d_ewa(
mean3d, cov3d_mat, viewmat, fx, fy, tan_fovx, tan_fovy
)
ii, jj = torch.triu_indices(2, 2)
Expand Down

0 comments on commit bb01f14

Please sign in to comment.