Skip to content

Commit

Permalink
Merge pull request #21 from vye16/zhuoyang/bindings
Browse files Browse the repository at this point in the history
Zhuoyang & Maturk Pybindings Update
  • Loading branch information
maturk authored Oct 3, 2023
2 parents d9e6499 + e179799 commit 986c66c
Show file tree
Hide file tree
Showing 15 changed files with 402 additions and 71 deletions.
66 changes: 55 additions & 11 deletions diff_rast/_torch_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Pure PyTorch implementations of various functions"""

import torch
import torch.nn.functional as F
import struct
from jaxtyping import Float
from torch import Tensor

Expand Down Expand Up @@ -155,17 +155,27 @@ def scale_rot_to_cov3d(scale: Tensor, glob_scale: float, quat: Tensor) -> Tensor


def project_cov3d_ewa(
mean3d: Tensor, cov3d: Tensor, viewmat: Tensor, fx: float, fy: float
mean3d: Tensor,
cov3d: Tensor,
viewmat: Tensor,
fx: float,
fy: float,
tan_fovx: float,
tan_fovy: float,
) -> Tensor:
assert mean3d.shape[-1] == 3, mean3d.shape
assert cov3d.shape[-2:] == (3, 3), cov3d.shape
assert viewmat.shape[-2:] == (4, 4), viewmat.shape
W = viewmat[..., :3, :3] # (..., 3, 3)
p = viewmat[..., :3, 3] # (..., 3)
t = torch.matmul(W, mean3d[..., None])[..., 0] + p # (..., 3)
raise NotImplementedError(
"Need to incorporate changes from this commit: 85e76e1c8b8e102145922f561800a74262ceb196!"
)

lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device)
lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device)

t[..., 0] = t[..., 2] * torch.min(lim_x, torch.max(-lim_x, t[..., 0] / t[..., 2]))
t[..., 1] = t[..., 2] * torch.min(lim_y, torch.max(-lim_y, t[..., 1] / t[..., 2]))

rz = 1.0 / t[..., 2] # (...,)
rz2 = rz**2 # (...,)
J = torch.stack(
Expand All @@ -178,8 +188,8 @@ def project_cov3d_ewa(
T = J @ W # (..., 2, 3)
cov2d = T @ cov3d @ T.transpose(-1, -2) # (..., 2, 2)
# add a little blur along axes and (TODO save upper triangular elements)
cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.1
cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.1
cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.3
cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3
return cov2d


Expand Down Expand Up @@ -215,11 +225,11 @@ def project_pix(mat, p, img_size, eps=1e-6):
return torch.stack([u, v], dim=-1)


def clip_near_plane(p, viewmat, thresh=0.1):
def clip_near_plane(p, viewmat, clip_thresh=0.01):
R = viewmat[..., :3, :3]
T = viewmat[..., :3, 3]
p_view = torch.matmul(R, p[..., None])[..., 0] + T
return p_view, p_view[..., 2] < thresh
return p_view, p_view[..., 2] < clip_thresh


def get_tile_bbox(pix_center, pix_radius, tile_bounds, BLOCK_X=16, BLOCK_Y=16):
Expand Down Expand Up @@ -259,10 +269,13 @@ def project_gaussians_forward(
fy,
img_size,
tile_bounds,
clip_thresh=0.01,
):
p_view, is_close = clip_near_plane(means3d, viewmat)
tan_fovx = 0.5 * img_size[1] / fx
tan_fovy = 0.5 * img_size[0] / fy
p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh)
cov3d = scale_rot_to_cov3d(scales, glob_scale, quats)
cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy)
cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy)
conic, radius, det_valid = compute_cov2d_bounds(cov2d)
center = project_pix(projmat, means3d, img_size)
tile_min, tile_max = get_tile_bbox(center, radius, tile_bounds)
Expand All @@ -278,3 +291,34 @@ def project_gaussians_forward(
conics = conic

return cov3d, xys, depths, radii, conics, num_tiles_hit, mask


def map_gaussian_to_intersects(
    num_points, xys, depths, radii, cum_tiles_hit, tile_bounds
):
    """Build the unsorted (tile, depth) intersection keys for every gaussian.

    Pure PyTorch counterpart of the ``map_gaussian_to_intersects`` CUDA
    binding. Gaussian ``idx`` writes one entry per tile of its bounding box
    into consecutive output slots starting at ``cum_tiles_hit[idx - 1]``
    (slot 0 for the first gaussian).

    Args:
        num_points: number of gaussians to process.
        xys: per-gaussian centers, forwarded to ``get_tile_bbox``.
        depths: per-gaussian depth; its float32 bit pattern becomes the low
            part of the intersection key.
        radii: per-gaussian radius; gaussians with a non-positive radius
            contribute no intersections and are skipped.
        cum_tiles_hit: running total of tiles hit; the last entry is the
            total number of intersections. It is only read, never modified.
        tile_bounds: tile grid dimensions; ``tile_bounds[0]`` is used as the
            row stride when flattening a tile coordinate to a tile id.

    Returns:
        Tuple ``(isect_ids, gaussian_ids)``: int64 keys of the form
        ``(tile_id << 32) | depth_bits`` and the int32 index of the gaussian
        that produced each key.
    """
    # An empty prefix-sum means there is nothing to intersect.
    num_intersects = int(cum_tiles_hit[-1]) if len(cum_tiles_hit) > 0 else 0
    isect_ids = torch.zeros(num_intersects, dtype=torch.int64, device=xys.device)
    gaussian_ids = torch.zeros(num_intersects, dtype=torch.int32, device=xys.device)

    for idx in range(num_points):
        # Skip (rather than stop at) gaussians that hit no tiles: the output
        # is sized for every gaussian, so later ones must still be written.
        if radii[idx] <= 0:
            continue

        tile_min, tile_max = get_tile_bbox(xys[idx], radii[idx], tile_bounds)
        x_min, y_min = int(tile_min[0]), int(tile_min[1])
        x_max, y_max = int(tile_max[0]), int(tile_max[1])

        # Convert to a plain int: indexing a tensor yields a view, and
        # incrementing that view in place would corrupt ``cum_tiles_hit``.
        cur_idx = 0 if idx == 0 else int(cum_tiles_hit[idx - 1])

        # Reinterpret the float32 depth bits as an int32.
        # NOTE(review): a negative depth yields a negative value here whose
        # sign bits would overwrite the tile id in the key below; presumably
        # depths are positive at this point -- confirm.
        raw_bytes = struct.pack("f", float(depths[idx]))
        depth_id_n = struct.unpack("i", raw_bytes)[0]

        for i in range(y_min, y_max):
            for j in range(x_min, x_max):
                tile_id = i * tile_bounds[0] + j
                isect_ids[cur_idx] = (tile_id << 32) | depth_id_n
                gaussian_ids[cur_idx] = idx
                cur_idx += 1

    return isect_ids, gaussian_ids
2 changes: 2 additions & 0 deletions diff_rast/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ def call_cuda(*args, **kwargs):
project_gaussians_backward = _make_lazy_cuda_func("project_gaussians_backward")
compute_sh_forward = _make_lazy_cuda_func("compute_sh_forward")
compute_sh_backward = _make_lazy_cuda_func("compute_sh_backward")
compute_cumulative_intersects = _make_lazy_cuda_func("compute_cumulative_intersects")
map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects")
2 changes: 1 addition & 1 deletion diff_rast/cuda/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ target_include_directories(check_serial_backward PRIVATE
)
target_include_directories(check_serial_forward PRIVATE
${PROJECT_SOURCE_DIR}/third_party/glm
)
)
8 changes: 4 additions & 4 deletions diff_rast/cuda/csrc/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace cg = cooperative_groups;

template<int CHANNELS>
template <int CHANNELS>
__global__ void rasterize_backward_kernel(
const dim3 tile_bounds,
const dim3 img_size,
Expand Down Expand Up @@ -49,7 +49,8 @@ __global__ void rasterize_backward_kernel(
float T_final = final_Ts[pix_id];
float T = T_final;
// the contribution from gaussians behind the current one
float S[CHANNELS] = {0.f}; // TODO: this currently doesn't match the channel count input.
float S[CHANNELS] = {
0.f}; // TODO: this currently doesn't match the channel count input.
// S[0] = 0.0;
// S[1] = 0.0;
// S[2] = 0.0;
Expand Down Expand Up @@ -96,7 +97,6 @@ __global__ void rasterize_backward_kernel(
S[c] += rgbs[CHANNELS * g + c] * fac;
}


// v_alpha = (rgb.x * T - S.x * ra) * v_out.x
// + (rgb.y * T - S.y * ra) * v_out.y
// + (rgb.z * T - S.z * ra) * v_out.z;
Expand Down Expand Up @@ -146,7 +146,7 @@ void rasterize_backward_impl(
float *v_opacity

) {
rasterize_backward_kernel<3> <<<tile_bounds, block>>>(
rasterize_backward_kernel<3><<<tile_bounds, block>>>(
tile_bounds,
img_size,
gaussians_ids_sorted,
Expand Down
77 changes: 75 additions & 2 deletions diff_rast/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cstdio>
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
Expand All @@ -29,8 +31,9 @@ __global__ void compute_cov2d_bounds_forward_kernel(
float3 conic;
float radius;
float3 cov2d{
(float)covs2d[index], (float)covs2d[index + 1], (float)covs2d[index + 2]
};
(float)covs2d[index],
(float)covs2d[index + 1],
(float)covs2d[index + 2]};
compute_cov2d_bounds(cov2d, conic, radius);
conics[index] = conic.x;
conics[index + 1] = conic.y;
Expand Down Expand Up @@ -253,3 +256,73 @@ project_gaussians_backward_tensor(

return std::make_tuple(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat);
}

std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(
const int num_points, torch::Tensor &num_tiles_hit
) {
// ref:
// https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a
// allocate sum workspace
CHECK_INPUT(num_tiles_hit);

torch::Tensor cum_tiles_hit = torch::zeros(
{num_points}, num_tiles_hit.options().dtype(torch::kInt32)
);

int32_t num_intersects;
compute_cumulative_intersects(
num_points,
num_tiles_hit.contiguous().data_ptr<int32_t>(),
num_intersects,
cum_tiles_hit.contiguous().data_ptr<int32_t>()
);

return std::make_tuple(
torch::tensor(
num_intersects, num_tiles_hit.options().dtype(torch::kInt32)
),
cum_tiles_hit
);
}

// Torch-facing wrapper that launches the map_gaussian_to_intersects kernel.
//
// Inputs are validated with CHECK_INPUT. The total number of intersections is
// read from the last used entry of cum_tiles_hit, and two zero-initialised
// output tensors of that length are allocated with xys' tensor options.
//
// Returns a tuple (isect_ids_unsorted, gaussian_ids_unsorted): an int64 tensor
// and an int32 tensor, both filled by the kernel.
//
// When num_points is not positive there is nothing to map, so empty outputs
// are returned without touching cum_tiles_hit or launching a zero-block grid.
std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
    const int num_points,
    torch::Tensor &xys,
    torch::Tensor &depths,
    torch::Tensor &radii,
    torch::Tensor &cum_tiles_hit,
    const std::tuple<int, int, int> tile_bounds
) {
    CHECK_INPUT(xys);
    CHECK_INPUT(depths);
    CHECK_INPUT(radii);
    CHECK_INPUT(cum_tiles_hit);

    const auto gaussian_id_opts = xys.options().dtype(torch::kInt32);
    const auto isect_id_opts = xys.options().dtype(torch::kInt64);

    // Guard the empty case: indexing cum_tiles_hit[num_points - 1] and a
    // launch with zero blocks are both invalid when there are no points.
    if (num_points <= 0) {
        return std::make_tuple(
            torch::zeros({0}, isect_id_opts),
            torch::zeros({0}, gaussian_id_opts)
        );
    }

    dim3 tile_bounds_dim3;
    tile_bounds_dim3.x = std::get<0>(tile_bounds);
    tile_bounds_dim3.y = std::get<1>(tile_bounds);
    tile_bounds_dim3.z = std::get<2>(tile_bounds);

    // Last running-total entry == total number of (gaussian, tile) pairs.
    int32_t num_intersects = cum_tiles_hit[num_points - 1].item<int32_t>();

    torch::Tensor gaussian_ids_unsorted =
        torch::zeros({num_intersects}, gaussian_id_opts);
    torch::Tensor isect_ids_unsorted =
        torch::zeros({num_intersects}, isect_id_opts);

    // Enough N_THREADS-sized blocks to give every point a thread.
    map_gaussian_to_intersects<<<
        (num_points + N_THREADS - 1) / N_THREADS,
        N_THREADS>>>(
        num_points,
        (float2 *)xys.contiguous().data_ptr<float>(),
        depths.contiguous().data_ptr<float>(),
        radii.contiguous().data_ptr<int32_t>(),
        cum_tiles_hit.contiguous().data_ptr<int32_t>(),
        tile_bounds_dim3,
        // Outputs.
        isect_ids_unsorted.contiguous().data_ptr<int64_t>(),
        gaussian_ids_unsorted.contiguous().data_ptr<int32_t>()
    );

    return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted);
}
13 changes: 13 additions & 0 deletions diff_rast/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,16 @@ project_gaussians_backward_tensor(
torch::Tensor &v_xy,
torch::Tensor &v_conic
);

// Torch binding: takes the number of points and the per-point tile counts and
// returns a pair of tensors.
// NOTE(review): the definition in bindings.cu returns
// (num_intersects, cum_tiles_hit) -- keep this comment in sync with it.
std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(
const int num_points, torch::Tensor &num_tiles_hit
);

// Torch binding: maps each point to its tile intersections and returns a pair
// of tensors.
// NOTE(review): the definition in bindings.cu returns
// (isect_ids_unsorted, gaussian_ids_unsorted) -- keep this comment in sync
// with it.
std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
const int num_points,
torch::Tensor &xys,
torch::Tensor &depths,
torch::Tensor &radii,
torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
);
2 changes: 2 additions & 0 deletions diff_rast/cuda/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("project_gaussians_backward", &project_gaussians_backward_tensor);
m.def("compute_sh_forward", &compute_sh_forward_tensor);
m.def("compute_sh_backward", &compute_sh_backward_tensor);
m.def("compute_cumulative_intersects", &compute_cumulative_intersects_tensor);
m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor);
}
11 changes: 11 additions & 0 deletions diff_rast/cuda/csrc/forward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,14 @@ __host__ __device__ float3 project_cov3d_ewa(
__host__ __device__ void scale_rot_to_cov3d(
const float3 scale, const float glob_scale, const float4 quat, float *cov3d
);

// Kernel declaration. The const pointer parameters are inputs; isect_ids and
// gaussian_ids are the non-const output buffers.
// NOTE(review): the host wrapper launches this with
// ceil(num_points / N_THREADS) blocks of N_THREADS threads, so presumably each
// thread handles one point -- confirm against the kernel body.
__global__ void map_gaussian_to_intersects(
const int num_points,
const float2 *xys,
const float *depths,
const int *radii,
const int32_t *cum_tiles_hit,
const dim3 tile_bounds,
int64_t *isect_ids,
int32_t *gaussian_ids
);
Loading

0 comments on commit 986c66c

Please sign in to comment.