diff --git a/diff_rast/_torch_impl.py b/diff_rast/_torch_impl.py index 5d2b25af8..f100d1624 100644 --- a/diff_rast/_torch_impl.py +++ b/diff_rast/_torch_impl.py @@ -1,7 +1,7 @@ """Pure PyTorch implementations of various functions""" - import torch import torch.nn.functional as F +import struct from jaxtyping import Float from torch import Tensor @@ -155,7 +155,13 @@ def scale_rot_to_cov3d(scale: Tensor, glob_scale: float, quat: Tensor) -> Tensor def project_cov3d_ewa( - mean3d: Tensor, cov3d: Tensor, viewmat: Tensor, fx: float, fy: float + mean3d: Tensor, + cov3d: Tensor, + viewmat: Tensor, + fx: float, + fy: float, + tan_fovx: float, + tan_fovy: float, ) -> Tensor: assert mean3d.shape[-1] == 3, mean3d.shape assert cov3d.shape[-2:] == (3, 3), cov3d.shape @@ -163,9 +169,13 @@ def project_cov3d_ewa( W = viewmat[..., :3, :3] # (..., 3, 3) p = viewmat[..., :3, 3] # (..., 3) t = torch.matmul(W, mean3d[..., None])[..., 0] + p # (..., 3) - raise NotImplementedError( - "Need to incorporate changes from this commit: 85e76e1c8b8e102145922f561800a74262ceb196!" - ) + + lim_x = 1.3 * torch.tensor([tan_fovx], device=mean3d.device) + lim_y = 1.3 * torch.tensor([tan_fovy], device=mean3d.device) + + t[..., 0] = t[..., 2] * torch.min(lim_x, torch.max(-lim_x, t[..., 0] / t[..., 2])) + t[..., 1] = t[..., 2] * torch.min(lim_y, torch.max(-lim_y, t[..., 1] / t[..., 2])) + rz = 1.0 / t[..., 2] # (...,) rz2 = rz**2 # (...,) J = torch.stack( @@ -178,8 +188,8 @@ def project_cov3d_ewa( T = J @ W # (..., 2, 3) cov2d = T @ cov3d @ T.transpose(-1, -2) # (..., 2, 2) # add a little blur along axes and (TODO save upper triangular elements) - cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.1 - cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.1 + cov2d[..., 0, 0] = cov2d[..., 0, 0] + 0.3 + cov2d[..., 1, 1] = cov2d[..., 1, 1] + 0.3 return cov2d @@ -215,11 +225,11 @@ def project_pix(mat, p, img_size, eps=1e-6): return torch.stack([u, v], dim=-1) -def clip_near_plane(p, viewmat, thresh=0.1): +def clip_near_plane(p, viewmat, clip_thresh=0.01): R = viewmat[..., :3, :3] T = viewmat[..., :3, 3] p_view = torch.matmul(R, p[..., None])[..., 0] + T - return p_view, p_view[..., 2] < thresh + return p_view, p_view[..., 2] < clip_thresh def get_tile_bbox(pix_center, pix_radius, tile_bounds, BLOCK_X=16, BLOCK_Y=16): @@ -259,10 +269,13 @@ def project_gaussians_forward( fy, img_size, tile_bounds, + clip_thresh=0.01, ): - p_view, is_close = clip_near_plane(means3d, viewmat) + tan_fovx = 0.5 * img_size[1] / fx + tan_fovy = 0.5 * img_size[0] / fy + p_view, is_close = clip_near_plane(means3d, viewmat, clip_thresh) cov3d = scale_rot_to_cov3d(scales, glob_scale, quats) - cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy) + cov2d = project_cov3d_ewa(means3d, cov3d, viewmat, fx, fy, tan_fovx, tan_fovy) conic, radius, det_valid = compute_cov2d_bounds(cov2d) center = project_pix(projmat, means3d, img_size) tile_min, tile_max = get_tile_bbox(center, radius, tile_bounds) @@ -278,3 +291,34 @@ def project_gaussians_forward( conics = conic return cov3d, xys, depths, radii, conics, num_tiles_hit, mask + + +def map_gaussian_to_intersects( + num_points, xys, depths, radii, cum_tiles_hit, tile_bounds +): + num_intersects = cum_tiles_hit[-1] + isect_ids = torch.zeros(num_intersects, dtype=torch.int64, device=xys.device) + gaussian_ids = torch.zeros(num_intersects, dtype=torch.int32, device=xys.device) + + for idx in range(num_points): + if radii[idx] <= 0: + break + + tile_min, tile_max = get_tile_bbox(xys[idx], radii[idx], tile_bounds) + + cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1] + + # Get raw byte representation of the float value at the given index + raw_bytes = struct.pack("f", depths[idx]) + + # Interpret those bytes as an int32_t + depth_id_n = struct.unpack("i", raw_bytes)[0] + + for i in range(tile_min[1], tile_max[1]): + for j in range(tile_min[0], tile_max[0]): + tile_id = i * tile_bounds[0] + j + isect_ids[cur_idx] = (tile_id << 32) | depth_id_n + gaussian_ids[cur_idx] = idx + cur_idx += 1 + + return isect_ids, gaussian_ids diff --git a/diff_rast/cuda/__init__.py b/diff_rast/cuda/__init__.py index f995b41fe..d418081cd 100644 --- a/diff_rast/cuda/__init__.py +++ b/diff_rast/cuda/__init__.py @@ -18,3 +18,5 @@ def call_cuda(*args, **kwargs): project_gaussians_backward = _make_lazy_cuda_func("project_gaussians_backward") compute_sh_forward = _make_lazy_cuda_func("compute_sh_forward") compute_sh_backward = _make_lazy_cuda_func("compute_sh_backward") +compute_cumulative_intersects = _make_lazy_cuda_func("compute_cumulative_intersects") +map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects") diff --git a/diff_rast/cuda/csrc/CMakeLists.txt b/diff_rast/cuda/csrc/CMakeLists.txt index 24cbc8364..12a4ba8e3 100644 --- a/diff_rast/cuda/csrc/CMakeLists.txt +++ b/diff_rast/cuda/csrc/CMakeLists.txt @@ -54,4 +54,4 @@ target_include_directories(check_serial_backward PRIVATE ) target_include_directories(check_serial_forward PRIVATE ${PROJECT_SOURCE_DIR}/third_party/glm -) +) \ No newline at end of file diff --git a/diff_rast/cuda/csrc/backward.cu b/diff_rast/cuda/csrc/backward.cu index a94a265cb..fedc475b1 100644 --- a/diff_rast/cuda/csrc/backward.cu +++ b/diff_rast/cuda/csrc/backward.cu @@ -4,7 +4,7 @@ namespace cg = cooperative_groups; -template +template __global__ void rasterize_backward_kernel( const dim3 tile_bounds, const dim3 img_size, @@ -49,7 +49,8 @@ __global__ void rasterize_backward_kernel( float T_final = final_Ts[pix_id]; float T = T_final; // the contribution from gaussians behind the current one - float S[CHANNELS] = {0.f}; // TODO: this currently doesn't match the channel count input. + float S[CHANNELS] = { + 0.f}; // TODO: this currently doesn't match the channel count input. // S[0] = 0.0; // S[1] = 0.0; // S[2] = 0.0; @@ -96,7 +97,6 @@ __global__ void rasterize_backward_kernel( S[c] += rgbs[CHANNELS * g + c] * fac; } - // v_alpha = (rgb.x * T - S.x * ra) * v_out.x // + (rgb.y * T - S.y * ra) * v_out.y // + (rgb.z * T - S.z * ra) * v_out.z; @@ -146,7 +146,7 @@ void rasterize_backward_impl( float *v_opacity ) { - rasterize_backward_kernel<3> <<>>( + rasterize_backward_kernel<3><<>>( tile_bounds, img_size, gaussians_ids_sorted, diff --git a/diff_rast/cuda/csrc/bindings.cu b/diff_rast/cuda/csrc/bindings.cu index 429f74e25..672384b25 100644 --- a/diff_rast/cuda/csrc/bindings.cu +++ b/diff_rast/cuda/csrc/bindings.cu @@ -8,6 +8,8 @@ #include #include #include +#include +#include #include #include #include @@ -29,8 +31,9 @@ __global__ void compute_cov2d_bounds_forward_kernel( float3 conic; float radius; float3 cov2d{ - (float)covs2d[index], (float)covs2d[index + 1], (float)covs2d[index + 2] - }; + (float)covs2d[index], + (float)covs2d[index + 1], + (float)covs2d[index + 2]}; compute_cov2d_bounds(cov2d, conic, radius); conics[index] = conic.x; conics[index + 1] = conic.y; @@ -253,3 +256,73 @@ project_gaussians_backward_tensor( return std::make_tuple(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat); } + +std::tuple compute_cumulative_intersects_tensor( + const int num_points, torch::Tensor &num_tiles_hit +) { + // ref: + // https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a + // allocate sum workspace + CHECK_INPUT(num_tiles_hit); + + torch::Tensor cum_tiles_hit = torch::zeros( + {num_points}, num_tiles_hit.options().dtype(torch::kInt32) + ); + + int32_t num_intersects; + compute_cumulative_intersects( + num_points, + num_tiles_hit.contiguous().data_ptr(), + num_intersects, + cum_tiles_hit.contiguous().data_ptr() + ); + + return std::make_tuple( + torch::tensor( + num_intersects, num_tiles_hit.options().dtype(torch::kInt32) + ), + cum_tiles_hit + ); +} + +std::tuple map_gaussian_to_intersects_tensor( + const int num_points, + torch::Tensor &xys, + torch::Tensor &depths, + torch::Tensor &radii, + torch::Tensor &cum_tiles_hit, + const std::tuple tile_bounds +) { + CHECK_INPUT(xys); + CHECK_INPUT(depths); + CHECK_INPUT(radii); + CHECK_INPUT(cum_tiles_hit); + + dim3 tile_bounds_dim3; + tile_bounds_dim3.x = std::get<0>(tile_bounds); + tile_bounds_dim3.y = std::get<1>(tile_bounds); + tile_bounds_dim3.z = std::get<2>(tile_bounds); + + int32_t num_intersects = cum_tiles_hit[num_points - 1].item(); + + torch::Tensor gaussian_ids_unsorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt32)); + torch::Tensor isect_ids_unsorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt64)); + + map_gaussian_to_intersects<<< + (num_points + N_THREADS - 1) / N_THREADS, + N_THREADS>>>( + num_points, + (float2 *)xys.contiguous().data_ptr(), + depths.contiguous().data_ptr(), + radii.contiguous().data_ptr(), + cum_tiles_hit.contiguous().data_ptr(), + tile_bounds_dim3, + // Outputs. + isect_ids_unsorted.contiguous().data_ptr(), + gaussian_ids_unsorted.contiguous().data_ptr() + ); + + return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted); +} diff --git a/diff_rast/cuda/csrc/bindings.h b/diff_rast/cuda/csrc/bindings.h index 0b392fdff..574c121de 100644 --- a/diff_rast/cuda/csrc/bindings.h +++ b/diff_rast/cuda/csrc/bindings.h @@ -79,3 +79,16 @@ project_gaussians_backward_tensor( torch::Tensor &v_xy, torch::Tensor &v_conic ); + +std::tuple compute_cumulative_intersects_tensor( + const int num_points, torch::Tensor &num_tiles_hit +); + +std::tuple map_gaussian_to_intersects_tensor( + const int num_points, + torch::Tensor &xys, + torch::Tensor &depths, + torch::Tensor &radii, + torch::Tensor &cum_tiles_hit, + const std::tuple tile_bounds +); \ No newline at end of file diff --git a/diff_rast/cuda/csrc/ext.cpp b/diff_rast/cuda/csrc/ext.cpp index 9e77bc3cf..1df3cd8ae 100644 --- a/diff_rast/cuda/csrc/ext.cpp +++ b/diff_rast/cuda/csrc/ext.cpp @@ -10,4 +10,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("project_gaussians_backward", &project_gaussians_backward_tensor); m.def("compute_sh_forward", &compute_sh_forward_tensor); m.def("compute_sh_backward", &compute_sh_backward_tensor); + m.def("compute_cumulative_intersects", &compute_cumulative_intersects_tensor); + m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor); } diff --git a/diff_rast/cuda/csrc/forward.cuh b/diff_rast/cuda/csrc/forward.cuh index 34ab8f84d..624b8ef50 100644 --- a/diff_rast/cuda/csrc/forward.cuh +++ b/diff_rast/cuda/csrc/forward.cuh @@ -78,3 +78,14 @@ __host__ __device__ float3 project_cov3d_ewa( __host__ __device__ void scale_rot_to_cov3d( const float3 scale, const float glob_scale, const float4 quat, float *cov3d ); + +__global__ void map_gaussian_to_intersects( + const int num_points, + const float2 *xys, + const float *depths, + const int *radii, + const int32_t *cum_tiles_hit, + const dim3 tile_bounds, + int64_t *isect_ids, + int32_t *gaussian_ids +); \ No newline at end of file diff --git a/diff_rast/cuda/csrc/helpers.cuh b/diff_rast/cuda/csrc/helpers.cuh index 884857d94..962c29dfc 100644 --- a/diff_rast/cuda/csrc/helpers.cuh +++ b/diff_rast/cuda/csrc/helpers.cuh @@ -31,7 +31,8 @@ inline __host__ __device__ void get_tile_bbox( uint2 &tile_min, uint2 &tile_max ) { - // gets gaussian dimensions in tile space, i.e. the span of a gaussian in tile_grid (image divided into tiles) + // gets gaussian dimensions in tile space, i.e. the span of a gaussian in + // tile_grid (image divided into tiles) float2 tile_center = { pix_center.x / (float)BLOCK_X, pix_center.y / (float)BLOCK_Y}; float2 tile_radius = { @@ -44,7 +45,8 @@ compute_cov2d_bounds(const float3 cov2d, float3 &conic, float &radius) { // find eigenvalues of 2d covariance matrix // expects upper triangular values of cov matrix as float3 // then compute the radius and conic dimensions - // the conic is the inverse cov2d matrix, represented here with upper triangular values. + // the conic is the inverse cov2d matrix, represented here with upper + // triangular values. float det = cov2d.x * cov2d.z - cov2d.y * cov2d.y; if (det == 0.f) return false; @@ -64,8 +66,9 @@ compute_cov2d_bounds(const float3 cov2d, float3 &conic, float &radius) { } // compute vjp from df/d_conic to df/c_cov2d -inline __host__ __device__ void -cov2d_to_conic_vjp(const float3 &conic, const float3 &v_conic, float3 &v_cov2d) { +inline __host__ __device__ void cov2d_to_conic_vjp( + const float3 &conic, const float3 &v_conic, float3 &v_cov2d +) { // conic = inverse cov2d // df/d_cov2d = -conic * df/d_conic * conic glm::mat2 X = glm::mat2(conic.x, conic.y, conic.y, conic.z); @@ -77,7 +80,8 @@ cov2d_to_conic_vjp(const float3 &conic, const float3 &v_conic, float3 &v_cov2d) } // helper for applying R * p + T, expect mat to be ROW MAJOR -inline __host__ __device__ float3 transform_4x3(const float *mat, const float3 p) { +inline __host__ __device__ float3 +transform_4x3(const float *mat, const float3 p) { float3 out = { mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3], mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7], @@ -88,7 +92,8 @@ inline __host__ __device__ float3 transform_4x3(const float *mat, const float3 p // helper to apply 4x4 transform to 3d vector, return homo coords // expects mat to be ROW MAJOR -inline __host__ __device__ float4 transform_4x4(const float *mat, const float3 p) { +inline __host__ __device__ float4 +transform_4x4(const float *mat, const float3 p) { float4 out = { mat[0] * p.x + mat[1] * p.y + mat[2] * p.z + mat[3], mat[4] * p.x + mat[5] * p.y + mat[6] * p.z + mat[7], @@ -117,8 +122,7 @@ inline __host__ __device__ float3 project_pix_vjp( float3 v_ndc = {0.5f * img_size.x * v_xy.x, 0.5f * img_size.y * v_xy.y}; float4 v_proj = { - v_ndc.x * rw, v_ndc.y * rw, 0., -(v_ndc.x + v_ndc.y) * rw * rw - }; + v_ndc.x * rw, v_ndc.y * rw, 0., -(v_ndc.x + v_ndc.y) * rw * rw}; // df / d_world = df / d_cam * d_cam / d_world // = v_proj * P[:3, :3] return { @@ -164,36 +168,36 @@ quat_to_rotmat_vjp(const float4 quat, const glm::mat3 v_R) { float4 v_quat; // v_R is COLUMN MAJOR // w element stored in x field - v_quat.x = 2.f * ( - // v_quat.w = 2.f * ( - x * (v_R[1][2] - v_R[2][1]) - + y * (v_R[2][0] - v_R[0][2]) - + z * (v_R[0][1] - v_R[1][0]) - ); + v_quat.x = + 2.f * ( + // v_quat.w = 2.f * ( + x * (v_R[1][2] - v_R[2][1]) + y * (v_R[2][0] - v_R[0][2]) + + z * (v_R[0][1] - v_R[1][0]) + ); // x element in y field - v_quat.y = 2.f * ( - // v_quat.x = 2.f * ( - -2.f * x * (v_R[1][1] + v_R[2][2]) - + y * (v_R[0][1] + v_R[1][0]) - + z * (v_R[0][2] + v_R[2][0]) - + w * (v_R[1][2] - v_R[2][1]) - ); + v_quat.y = + 2.f * + ( + // v_quat.x = 2.f * ( + -2.f * x * (v_R[1][1] + v_R[2][2]) + y * (v_R[0][1] + v_R[1][0]) + + z * (v_R[0][2] + v_R[2][0]) + w * (v_R[1][2] - v_R[2][1]) + ); // y element in z field - v_quat.z = 2.f * ( - // v_quat.y = 2.f * ( - x * (v_R[0][1] + v_R[1][0]) - - 2.f * y * (v_R[0][0] + v_R[2][2]) - + z * (v_R[1][2] + v_R[2][1]) - + w * (v_R[2][0] - v_R[0][2]) - ); + v_quat.z = + 2.f * + ( + // v_quat.y = 2.f * ( + x * (v_R[0][1] + v_R[1][0]) - 2.f * y * (v_R[0][0] + v_R[2][2]) + + z * (v_R[1][2] + v_R[2][1]) + w * (v_R[2][0] - v_R[0][2]) + ); // z element in w field - v_quat.w = 2.f * ( - // v_quat.z = 2.f * ( - x * (v_R[0][2] + v_R[2][0]) - + y * (v_R[1][2] + v_R[2][1]) - - 2.f * z * (v_R[0][0] + v_R[1][1]) - + w * (v_R[0][1] - v_R[1][0]) - ); + v_quat.w = + 2.f * + ( + // v_quat.z = 2.f * ( + x * (v_R[0][2] + v_R[2][0]) + y * (v_R[1][2] + v_R[2][1]) - + 2.f * z * (v_R[0][0] + v_R[1][1]) + w * (v_R[0][1] - v_R[1][0]) + ); return v_quat; } @@ -207,8 +211,9 @@ scale_to_mat(const float3 scale, const float glob_scale) { } // device helper for culling near points -inline __host__ __device__ bool -clip_near_plane(const float3 p, const float *viewmat, float3 &p_view, float thresh) { +inline __host__ __device__ bool clip_near_plane( + const float3 p, const float *viewmat, float3 &p_view, float thresh +) { p_view = transform_4x3(viewmat, p); if (p_view.z <= thresh) { return true; diff --git a/diff_rast/cuda/csrc/serial_backward.cu b/diff_rast/cuda/csrc/serial_backward.cu index 14a28d3d8..a902b3234 100644 --- a/diff_rast/cuda/csrc/serial_backward.cu +++ b/diff_rast/cuda/csrc/serial_backward.cu @@ -164,8 +164,8 @@ computeConicBackward(const float3 &cov2D, const float3 &dL_dconic) { float denom2inv = 1.0f / ((denom * denom) + 0.0000001f); if (denom2inv != 0) { - // This is slightly different from the original implementation, but we include this line to make - // equality checks easier. + // This is slightly different from the original implementation, but we + // include this line to make equality checks easier. float denom2inv = 1.0f / (denom * denom); // Gradients of loss w.r.t. entries of 2D covariance matrix, // given gradients of loss w.r.t. conic matrix (inverse covariance diff --git a/diff_rast/cuda/csrc/serial_backward.cuh b/diff_rast/cuda/csrc/serial_backward.cuh index d26d38c60..58b5930f3 100644 --- a/diff_rast/cuda/csrc/serial_backward.cuh +++ b/diff_rast/cuda/csrc/serial_backward.cuh @@ -1,23 +1,20 @@ #include "cuda_runtime.h" - __host__ __device__ float3 projectMean2DBackward( - const float3 m, const float* proj, const float2 dL_dmean2D + const float3 m, const float *proj, const float2 dL_dmean2D ); __host__ __device__ void computeCov3DBackward( const float3 scale, const float mod, const float4 rot, - const float* dL_dcov3D, + const float *dL_dcov3D, float3 &dL_dscale, float4 &dL_dq ); -__host__ __device__ float3 computeConicBackward( - const float3 &cov2D, - const float3 &dL_dconic -); +__host__ __device__ float3 +computeConicBackward(const float3 &cov2D, const float3 &dL_dconic); __host__ __device__ void computeCov2DBackward( const float3 &mean, @@ -25,7 +22,7 @@ __host__ __device__ void computeCov2DBackward( const float *view_matrix, const float h_x, const float h_y, - const float tan_fovx, + const float tan_fovx, const float tan_fovy, const float3 &dL_dcov2d, float3 &dL_dmean, diff --git a/diff_rast/cuda/csrc/sh.cuh b/diff_rast/cuda/csrc/sh.cuh index eecfb654d..f709e155d 100644 --- a/diff_rast/cuda/csrc/sh.cuh +++ b/diff_rast/cuda/csrc/sh.cuh @@ -10,8 +10,7 @@ __host__ __device__ const float SH_C2[] = { -1.0925484305920792f, 0.31539156525252005f, -1.0925484305920792f, - 0.5462742152960396f -}; + 0.5462742152960396f}; __host__ __device__ const float SH_C3[] = { -0.5900435899266435f, 2.890611442640554f, @@ -19,8 +18,7 @@ __host__ __device__ const float SH_C3[] = { 0.3731763325901154f, -0.4570457994644658f, 1.445305721320277f, - -0.5900435899266435f -}; + -0.5900435899266435f}; __host__ __device__ const float SH_C4[] = { 2.5033429417967046f, -1.7701307697799304, @@ -30,8 +28,7 @@ __host__ __device__ const float SH_C4[] = { -0.6690465435572892f, 0.47308734787878004f, -1.7701307697799304f, - 0.6258357354491761f -}; + 0.6258357354491761f}; __host__ __device__ unsigned num_sh_bases(const unsigned degree) { if (degree == 0) diff --git a/tests/test_cumulative_intersects.py b/tests/test_cumulative_intersects.py new file mode 100644 index 000000000..bebef8bc0 --- /dev/null +++ b/tests/test_cumulative_intersects.py @@ -0,0 +1,32 @@ +import pytest +import torch + + +device = torch.device("cuda:0") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_cumulative_intersects(): + import diff_rast.cuda as _C + + torch.manual_seed(42) + + num_points = 10 + + num_tiles_hit = torch.randint( + 0, 100, (num_points,), device=device, dtype=torch.int32 + ) + + num_intersects, cum_tiles_hit = _C.compute_cumulative_intersects( + num_points, num_tiles_hit + ) + + _cum_tiles_hit = torch.cumsum(num_tiles_hit, dim=0, dtype=torch.int32) + _num_intersects = _cum_tiles_hit[-1] + + torch.testing.assert_close(num_intersects, _num_intersects) + torch.testing.assert_close(cum_tiles_hit, _cum_tiles_hit) + + +if __name__ == "__main__": + test_cumulative_intersects() diff --git a/tests/test_map_gaussians.py b/tests/test_map_gaussians.py new file mode 100644 index 000000000..394fe7199 --- /dev/null +++ b/tests/test_map_gaussians.py @@ -0,0 +1,68 @@ +import pytest +import torch + + +device = torch.device("cuda:0") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_map_gaussians(): + from diff_rast import _torch_impl + import diff_rast.cuda as _C + + torch.manual_seed(42) + + num_points = 100 + means3d = torch.randn((num_points, 3), device=device, requires_grad=True) + scales = torch.randn((num_points, 3), device=device) + glob_scale = 0.3 + quats = torch.randn((num_points, 4), device=device) + quats /= torch.linalg.norm(quats, dim=-1, keepdim=True) + viewmat = torch.eye(4, device=device) + projmat = torch.eye(4, device=device) + fx, fy = 3.0, 3.0 + H, W = 512, 512 + clip_thresh = 0.01 + + BLOCK_X, BLOCK_Y = 16, 16 + tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1 + + ( + _cov3d, + _xys, + _depths, + _radii, + _conics, + _num_tiles_hit, + _masks, + ) = _torch_impl.project_gaussians_forward( + means3d, + scales, + glob_scale, + quats, + viewmat, + projmat, + fx, + fy, + (H, W), + tile_bounds, + clip_thresh, + ) + + _cum_tiles_hit = torch.cumsum(_num_tiles_hit, dim=0, dtype=torch.int32) + _depths = _depths.contiguous() + + isect_ids, gaussian_ids = _C.map_gaussian_to_intersects( + num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds + ) + + _isect_ids, _gaussian_ids = _torch_impl.map_gaussian_to_intersects( + num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds + ) + + torch.testing.assert_close(gaussian_ids, _gaussian_ids) + torch.testing.assert_close(isect_ids, _isect_ids) + + +if __name__ == "__main__": + test_map_gaussians() diff --git a/tests/test_project_gaussians.py b/tests/test_project_gaussians.py new file mode 100644 index 000000000..b0e761e43 --- /dev/null +++ b/tests/test_project_gaussians.py @@ -0,0 +1,87 @@ +import pytest +import torch + + +device = torch.device("cuda:0") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_project_gaussians_forward(): + from diff_rast import _torch_impl + import diff_rast.cuda as _C + + torch.manual_seed(42) + + num_points = 100 + means3d = torch.randn((num_points, 3), device=device, requires_grad=True) + scales = torch.randn((num_points, 3), device=device) + glob_scale = 0.3 + quats = torch.randn((num_points, 4), device=device) + quats /= torch.linalg.norm(quats, dim=-1, keepdim=True) + viewmat = torch.eye(4, device=device) + projmat = torch.eye(4, device=device) + fx, fy = 3.0, 3.0 + H, W = 512, 512 + clip_thresh = 0.01 + + BLOCK_X, BLOCK_Y = 16, 16 + tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1 + + (cov3d, xys, depths, radii, conics, num_tiles_hit,) = _C.project_gaussians_forward( + num_points, + means3d, + scales, + glob_scale, + quats, + viewmat, + projmat, + fx, + fy, + H, + W, + tile_bounds, + clip_thresh, + ) + + ( + _cov3d, + _xys, + _depths, + _radii, + _conics, + _num_tiles_hit, + _masks, + ) = _torch_impl.project_gaussians_forward( + means3d, + scales, + glob_scale, + quats, + viewmat, + projmat, + fx, + fy, + (H, W), + tile_bounds, + clip_thresh, + ) + + torch.testing.assert_close( + cov3d[_masks], + _cov3d.view(-1, 9)[_masks][:, [0, 1, 2, 4, 5, 8]], + atol=1e-5, + rtol=1e-5, + ) + torch.testing.assert_close( + xys[_masks], + _xys[_masks], + atol=1e-4, + rtol=1e-4, + ) + torch.testing.assert_close(depths[_masks], _depths[_masks]) + torch.testing.assert_close(radii[_masks], _radii[_masks]) + torch.testing.assert_close(conics[_masks], _conics[_masks]) + torch.testing.assert_close(num_tiles_hit[_masks], _num_tiles_hit[_masks]) + + +if __name__ == "__main__": + test_project_gaussians_forward()