Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zhuoyang & Maturk Pybindings Update #21

Merged
merged 11 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 55 additions & 11 deletions diff_rast/_torch_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Pure PyTorch implementations of various functions"""

import torch
import torch.nn.functional as F
import struct
from jaxtyping import Float
from torch import Tensor

Expand Down Expand Up @@ -155,17 +155,27 @@ def scale_rot_to_cov3d(scale: Tensor, glob_scale: float, quat: Tensor) -> Tensor


def project_cov3d_ewa(
mean3d: Tensor, cov3d: Tensor, viewmat: Tensor, fx: float, fy: float
mean3d: Tensor,
cov3d: Tensor,
viewmat: Tensor,
fx: float,
fy: float,
tan_fovx: float,
tan_fovy: float,
) -> Tensor:
assert mean3d.shape[-1] == 3, mean3d.shape
assert cov3d.shape[-2:] == (3, 3), cov3d.shape
assert viewmat.shape[-2:] == (4, 4), viewmat.shape
W = viewmat[..., :3, :3] # (..., 3, 3)
p = viewmat[..., :3, 3] # (..., 3)
t = torch.matmul(W, mean3d[..., None])[..., 0] + p # (..., 3)
raise NotImplementedError(
"Need to incorporate changes from this commit: 85e76e1c8b8e102145922f561800a74262ceb196!"
)

lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device)
lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device)

t[..., 0] = t[..., 2] * torch.min(lim_x, torch.max(-lim_x, t[..., 0] / t[..., 2]))
t[..., 1] = t[..., 2] * torch.min(lim_y, torch.max(-lim_y, t[..., 1] / t[..., 2]))

rz = 1.0 / t[..., 2] # (...,)
rz2 = rz**2 # (...,)
J = torch.stack(
Expand All @@ -178,8 +188,8 @@ def project_cov3d_ewa(
T = J @ W # (..., 2, 3)
cov2d = T @ cov3d @ T.transpose(-1, -2) # (..., 2, 2)
# add a little blur along axes and (TODO save upper triangular elements)
cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.1
cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.1
cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.3
cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3
return cov2d


Expand Down Expand Up @@ -215,11 +225,11 @@ def project_pix(mat, p, img_size, eps=1e-6):
return torch.stack([u, v], dim=-1)


def clip_near_plane(p, viewmat, thresh=0.1):
def clip_near_plane(p, viewmat, clip_thresh=0.01):
R = viewmat[..., :3, :3]
T = viewmat[..., :3, 3]
p_view = torch.matmul(R, p[..., None])[..., 0] + T
return p_view, p_view[..., 2] < thresh
return p_view, p_view[..., 2] < clip_thresh


def get_tile_bbox(pix_center, pix_radius, tile_bounds, BLOCK_X=16, BLOCK_Y=16):
Expand Down Expand Up @@ -259,10 +269,13 @@ def project_gaussians_forward(
fy,
img_size,
tile_bounds,
clip_thresh=0.01,
):
p_view, is_close = clip_near_plane(means3d, viewmat)
tan_fovx = 0.5 * img_size[1] / fx
tan_fovy = 0.5 * img_size[0] / fy
p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh)
cov3d = scale_rot_to_cov3d(scales, glob_scale, quats)
cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy)
cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy)
conic, radius, det_valid = compute_cov2d_bounds(cov2d)
center = project_pix(projmat, means3d, img_size)
tile_min, tile_max = get_tile_bbox(center, radius, tile_bounds)
Expand All @@ -278,3 +291,34 @@ def project_gaussians_forward(
conics = conic

return cov3d, xys, depths, radii, conics, num_tiles_hit, mask


def map_gaussian_to_intersects(
num_points, xys, depths, radii, cum_tiles_hit, tile_bounds
):
num_intersects = cum_tiles_hit[-1]
isect_ids = torch.zeros(num_intersects, dtype=torch.int64, device=xys.device)
gaussian_ids = torch.zeros(num_intersects, dtype=torch.int32, device=xys.device)

for idx in range(num_points):
if radii[idx] <= 0:
break

tile_min, tile_max = get_tile_bbox(xys[idx], radii[idx], tile_bounds)

cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1]

# Get raw byte representation of the float value at the given index
raw_bytes = struct.pack("f", depths[idx])

# Interpret those bytes as an int32_t
depth_id_n = struct.unpack("i", raw_bytes)[0]

for i in range(tile_min[1], tile_max[1]):
for j in range(tile_min[0], tile_max[0]):
tile_id = i * tile_bounds[0] + j
isect_ids[cur_idx] = (tile_id << 32) | depth_id_n
gaussian_ids[cur_idx] = idx
cur_idx += 1

return isect_ids, gaussian_ids
2 changes: 2 additions & 0 deletions diff_rast/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ def call_cuda(*args, **kwargs):
project_gaussians_backward = _make_lazy_cuda_func("project_gaussians_backward")
compute_sh_forward = _make_lazy_cuda_func("compute_sh_forward")
compute_sh_backward = _make_lazy_cuda_func("compute_sh_backward")
compute_cumulative_intersects = _make_lazy_cuda_func("compute_cumulative_intersects")
map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects")
2 changes: 1 addition & 1 deletion diff_rast/cuda/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ target_include_directories(check_serial_backward PRIVATE
)
target_include_directories(check_serial_forward PRIVATE
${PROJECT_SOURCE_DIR}/third_party/glm
)
)
8 changes: 4 additions & 4 deletions diff_rast/cuda/csrc/backward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace cg = cooperative_groups;

template<int CHANNELS>
template <int CHANNELS>
__global__ void rasterize_backward_kernel(
const dim3 tile_bounds,
const dim3 img_size,
Expand Down Expand Up @@ -49,7 +49,8 @@ __global__ void rasterize_backward_kernel(
float T_final = final_Ts[pix_id];
float T = T_final;
// the contribution from gaussians behind the current one
float S[CHANNELS] = {0.f}; // TODO: this currently doesn't match the channel count input.
float S[CHANNELS] = {
0.f}; // TODO: this currently doesn't match the channel count input.
// S[0] = 0.0;
// S[1] = 0.0;
// S[2] = 0.0;
Expand Down Expand Up @@ -96,7 +97,6 @@ __global__ void rasterize_backward_kernel(
S[c] += rgbs[CHANNELS * g + c] * fac;
}


// v_alpha = (rgb.x * T - S.x * ra) * v_out.x
// + (rgb.y * T - S.y * ra) * v_out.y
// + (rgb.z * T - S.z * ra) * v_out.z;
Expand Down Expand Up @@ -146,7 +146,7 @@ void rasterize_backward_impl(
float *v_opacity

) {
rasterize_backward_kernel<3> <<<tile_bounds, block>>>(
rasterize_backward_kernel<3><<<tile_bounds, block>>>(
tile_bounds,
img_size,
gaussians_ids_sorted,
Expand Down
77 changes: 75 additions & 2 deletions diff_rast/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cstdio>
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
Expand All @@ -29,8 +31,9 @@ __global__ void compute_cov2d_bounds_forward_kernel(
float3 conic;
float radius;
float3 cov2d{
(float)covs2d[index], (float)covs2d[index + 1], (float)covs2d[index + 2]
};
(float)covs2d[index],
(float)covs2d[index + 1],
(float)covs2d[index + 2]};
compute_cov2d_bounds(cov2d, conic, radius);
conics[index] = conic.x;
conics[index + 1] = conic.y;
Expand Down Expand Up @@ -253,3 +256,73 @@ project_gaussians_backward_tensor(

return std::make_tuple(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat);
}

std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(
const int num_points, torch::Tensor &num_tiles_hit
) {
// ref:
// https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a
// allocate sum workspace
CHECK_INPUT(num_tiles_hit);

torch::Tensor cum_tiles_hit = torch::zeros(
{num_points}, num_tiles_hit.options().dtype(torch::kInt32)
);

int32_t num_intersects;
compute_cumulative_intersects(
num_points,
num_tiles_hit.contiguous().data_ptr<int32_t>(),
num_intersects,
cum_tiles_hit.contiguous().data_ptr<int32_t>()
);

return std::make_tuple(
torch::tensor(
num_intersects, num_tiles_hit.options().dtype(torch::kInt32)
),
cum_tiles_hit
);
}

std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
const int num_points,
torch::Tensor &xys,
torch::Tensor &depths,
torch::Tensor &radii,
torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
) {
CHECK_INPUT(xys);
CHECK_INPUT(depths);
CHECK_INPUT(radii);
CHECK_INPUT(cum_tiles_hit);

dim3 tile_bounds_dim3;
tile_bounds_dim3.x = std::get<0>(tile_bounds);
tile_bounds_dim3.y = std::get<1>(tile_bounds);
tile_bounds_dim3.z = std::get<2>(tile_bounds);

int32_t num_intersects = cum_tiles_hit[num_points - 1].item<int32_t>();

torch::Tensor gaussian_ids_unsorted =
torch::zeros({num_intersects}, xys.options().dtype(torch::kInt32));
torch::Tensor isect_ids_unsorted =
torch::zeros({num_intersects}, xys.options().dtype(torch::kInt64));

map_gaussian_to_intersects<<<
(num_points + N_THREADS - 1) / N_THREADS,
N_THREADS>>>(
num_points,
(float2 *)xys.contiguous().data_ptr<float>(),
depths.contiguous().data_ptr<float>(),
radii.contiguous().data_ptr<int32_t>(),
cum_tiles_hit.contiguous().data_ptr<int32_t>(),
tile_bounds_dim3,
// Outputs.
isect_ids_unsorted.contiguous().data_ptr<int64_t>(),
gaussian_ids_unsorted.contiguous().data_ptr<int32_t>()
);

return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted);
}
13 changes: 13 additions & 0 deletions diff_rast/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,16 @@ project_gaussians_backward_tensor(
torch::Tensor &v_xy,
torch::Tensor &v_conic
);

std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(
const int num_points, torch::Tensor &num_tiles_hit
);

std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
const int num_points,
torch::Tensor &xys,
torch::Tensor &depths,
torch::Tensor &radii,
torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
);
2 changes: 2 additions & 0 deletions diff_rast/cuda/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("project_gaussians_backward", &project_gaussians_backward_tensor);
m.def("compute_sh_forward", &compute_sh_forward_tensor);
m.def("compute_sh_backward", &compute_sh_backward_tensor);
m.def("compute_cumulative_intersects", &compute_cumulative_intersects_tensor);
m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor);
}
11 changes: 11 additions & 0 deletions diff_rast/cuda/csrc/forward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,14 @@ __host__ __device__ float3 project_cov3d_ewa(
__host__ __device__ void scale_rot_to_cov3d(
const float3 scale, const float glob_scale, const float4 quat, float *cov3d
);

__global__ void map_gaussian_to_intersects(
const int num_points,
const float2 *xys,
const float *depths,
const int *radii,
const int32_t *cum_tiles_hit,
const dim3 tile_bounds,
int64_t *isect_ids,
int32_t *gaussian_ids
);
Loading