Skip to content

Commit

Permalink
support packed and sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
jefequien committed Jul 15, 2024
1 parent 6d156d2 commit 9c06a24
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 8 deletions.
14 changes: 13 additions & 1 deletion gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,7 @@ def forward(
radii,
means2d,
depths,
normals,
conics,
compensations,
) = _make_lazy_cuda_func("fully_fused_projection_packed_fwd")(
Expand Down Expand Up @@ -994,7 +995,16 @@ def forward(
ctx.eps2d = eps2d
ctx.sparse_grad = sparse_grad

return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations
return (
camera_ids,
gaussian_ids,
radii,
means2d,
depths,
normals,
conics,
compensations,
)

@staticmethod
def backward(
Expand All @@ -1004,6 +1014,7 @@ def backward(
v_radii,
v_means2d,
v_depths,
v_normals,
v_conics,
v_compensations,
):
Expand Down Expand Up @@ -1044,6 +1055,7 @@ def backward(
compensations,
v_means2d.contiguous(),
v_depths.contiguous(),
v_normals.contiguous(),
v_conics.contiguous(),
v_compensations,
ctx.needs_input_grad[4], # viewmats_requires_grad
Expand Down
3 changes: 2 additions & 1 deletion gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ compute_sh_bwd_tensor(const uint32_t K, const uint32_t degrees_to_use,
* Packed Version
****************************************************************************************/
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor, torch::Tensor, torch::Tensor>
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
fully_fused_projection_packed_fwd_tensor(
const torch::Tensor &means, // [N, 3]
const at::optional<torch::Tensor> &covars, // [N, 6]
Expand Down Expand Up @@ -210,6 +210,7 @@ fully_fused_projection_packed_bwd_tensor(
// grad outputs
const torch::Tensor &v_means2d, // [nnz, 2]
const torch::Tensor &v_depths, // [nnz]
const torch::Tensor &v_normals, // [nnz, 3]
const torch::Tensor &v_conics, // [nnz, 3]
const at::optional<torch::Tensor> &v_compensations, // [nnz] optional
const bool viewmats_requires_grad, const bool sparse_grad);
Expand Down
2 changes: 1 addition & 1 deletion gsplat/cuda/csrc/fully_fused_projection_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N,
int32_t *__restrict__ radii, // [C, N]
T *__restrict__ means2d, // [C, N, 2]
T *__restrict__ depths, // [C, N]
T *__restrict__ normals, // [C, N, 3]
T *__restrict__ normals, // [C, N, 3]
T *__restrict__ conics, // [C, N, 3]
T *__restrict__ compensations // [C, N] optional
) {
Expand Down
17 changes: 16 additions & 1 deletion gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ __global__ void fully_fused_projection_packed_bwd_kernel(
// grad outputs
const T *__restrict__ v_means2d, // [nnz, 2]
const T *__restrict__ v_depths, // [nnz]
const T *__restrict__ v_normals, // [nnz, 3]
const T *__restrict__ v_conics, // [nnz, 3]
const T *__restrict__ v_compensations, // [nnz] optional
const bool sparse_grad, // whether the outputs are in COO format [nnz, ...]
Expand Down Expand Up @@ -61,6 +62,7 @@ __global__ void fully_fused_projection_packed_bwd_kernel(

v_means2d += idx * 2;
v_depths += idx;
v_normals += idx * 3;
v_conics += idx * 3;

// vjp: compute the inverse of the 2d covariance
Expand Down Expand Up @@ -154,6 +156,11 @@ __global__ void fully_fused_projection_packed_bwd_kernel(
v_scales[0] = v_scale[0];
v_scales[1] = v_scale[1];
v_scales[2] = v_scale[2];

// Accumulate the normal's gradient: the forward pass takes normal = R[:, 2]
// (third column of the rotation matrix), so v_normals flows into v_R[2]
// before the quaternion vjp. NOTE(review): verify this vjp against the forward pass.
mat3<T> v_R = quat_to_rotmat<T>(quat);
v_R[2] += glm::make_vec3(v_normals);
quat_to_rotmat_vjp<T>(quat, v_R, v_quat);
}
} else {
// write out results with dense layout
Expand Down Expand Up @@ -188,6 +195,12 @@ __global__ void fully_fused_projection_packed_bwd_kernel(
vec4<T> v_quat(0.f);
vec3<T> v_scale(0.f);
quat_scale_to_covar_vjp<T>(quat, scale, rotmat, v_covar, v_quat, v_scale);

// Accumulate the normal's gradient: forward computes normal = R[:, 2], so
// v_normals is added to the third column of v_R before quat_to_rotmat_vjp.
// NOTE(review): confirm correctness (mirrors the sparse branch above).
mat3<T> v_R = quat_to_rotmat<T>(quat);
v_R[2] += glm::make_vec3(v_normals);
quat_to_rotmat_vjp<T>(quat, v_R, v_quat);

warpSum(v_quat, warp_group_g);
warpSum(v_scale, warp_group_g);
if (warp_group_g.thread_rank() == 0) {
Expand Down Expand Up @@ -240,6 +253,7 @@ fully_fused_projection_packed_bwd_tensor(
// grad outputs
const torch::Tensor &v_means2d, // [nnz, 2]
const torch::Tensor &v_depths, // [nnz]
const torch::Tensor &v_normals, // [nnz, 3]
const torch::Tensor &v_conics, // [nnz, 3]
const at::optional<torch::Tensor> &v_compensations, // [nnz] optional
const bool viewmats_requires_grad, const bool sparse_grad) {
Expand All @@ -259,6 +273,7 @@ fully_fused_projection_packed_bwd_tensor(
CHECK_INPUT(conics);
CHECK_INPUT(v_means2d);
CHECK_INPUT(v_depths);
CHECK_INPUT(v_normals);
CHECK_INPUT(v_conics);
if (compensations.has_value()) {
CHECK_INPUT(compensations.value());
Expand Down Expand Up @@ -309,7 +324,7 @@ fully_fused_projection_packed_bwd_tensor(
compensations.has_value() ? compensations.value().data_ptr<float>()
: nullptr,
v_means2d.data_ptr<float>(), v_depths.data_ptr<float>(),
v_conics.data_ptr<float>(),
v_normals.data_ptr<float>(), v_conics.data_ptr<float>(),
v_compensations.has_value() ? v_compensations.value().data_ptr<float>()
: nullptr,
sparse_grad, v_means.data_ptr<float>(),
Expand Down
17 changes: 13 additions & 4 deletions gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel(
int32_t *__restrict__ radii, // [nnz]
T *__restrict__ means2d, // [nnz, 2]
T *__restrict__ depths, // [nnz]
T *__restrict__ normals, // [nnz, 3]
T *__restrict__ conics, // [nnz, 3]
T *__restrict__ compensations // [nnz] optional
) {
Expand Down Expand Up @@ -75,6 +76,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel(
mat2<T> covar2d;
vec2<T> mean2d;
mat2<T> covar2d_inv;
vec3<T> normal;
T compensation;
T det;
if (valid) {
Expand All @@ -92,6 +94,9 @@ __global__ void fully_fused_projection_packed_fwd_kernel(
quats += col_idx * 4;
scales += col_idx * 3;
quat_scale_to_covar_preci<T>(glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr);

glm::mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats));
normal = rotmat[2];
}
mat3<T> covar_c;
covar_world_to_cam(R, covar, covar_c);
Expand Down Expand Up @@ -163,6 +168,9 @@ __global__ void fully_fused_projection_packed_fwd_kernel(
means2d[thread_data * 2] = mean2d.x;
means2d[thread_data * 2 + 1] = mean2d.y;
depths[thread_data] = mean_c.z;
normals[thread_data * 3] = normal.x;
normals[thread_data * 3 + 1] = normal.y;
normals[thread_data * 3 + 2] = normal.z;
conics[thread_data * 3] = covar2d_inv[0][0];
conics[thread_data * 3 + 1] = covar2d_inv[0][1];
conics[thread_data * 3 + 2] = covar2d_inv[1][1];
Expand All @@ -183,7 +191,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel(
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
torch::Tensor, torch::Tensor, torch::Tensor>
torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
fully_fused_projection_packed_fwd_tensor(
const torch::Tensor &means, // [N, 3]
const at::optional<torch::Tensor> &covars, // [N, 6]
Expand Down Expand Up @@ -232,7 +240,7 @@ fully_fused_projection_packed_fwd_tensor(
viewmats.data_ptr<float>(), Ks.data_ptr<float>(), image_width, image_height,
eps2d, near_plane, far_plane, radius_clip, nullptr,
block_cnts.data_ptr<int32_t>(), nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr);
nullptr, nullptr, nullptr, nullptr);
block_accum = torch::cumsum(block_cnts, 0, torch::kInt32);
nnz = block_accum[-1].item<int32_t>();
} else {
Expand All @@ -246,6 +254,7 @@ fully_fused_projection_packed_fwd_tensor(
torch::Tensor radii = torch::empty({nnz}, means.options().dtype(torch::kInt32));
torch::Tensor means2d = torch::empty({nnz, 2}, means.options());
torch::Tensor depths = torch::empty({nnz}, means.options());
torch::Tensor normals = torch::empty({nnz, 3}, means.options());
torch::Tensor conics = torch::empty({nnz, 3}, means.options());
torch::Tensor compensations;
if (calc_compensations) {
Expand All @@ -264,12 +273,12 @@ fully_fused_projection_packed_fwd_tensor(
nullptr, indptr.data_ptr<int32_t>(), camera_ids.data_ptr<int64_t>(),
gaussian_ids.data_ptr<int64_t>(), radii.data_ptr<int32_t>(),
means2d.data_ptr<float>(), depths.data_ptr<float>(),
conics.data_ptr<float>(),
normals.data_ptr<float>(), conics.data_ptr<float>(),
calc_compensations ? compensations.data_ptr<float>() : nullptr);
} else {
indptr.fill_(0);
}

return std::make_tuple(indptr, camera_ids, gaussian_ids, radii, means2d, depths,
conics, compensations);
normals, conics, compensations);
}
1 change: 1 addition & 0 deletions gsplat/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def rasterization(
radii,
means2d,
depths,
normals,
conics,
compensations,
) = proj_results
Expand Down

0 comments on commit 9c06a24

Please sign in to comment.