diff --git a/gsplat/cuda/csrc/bindings.cu b/gsplat/cuda/csrc/bindings.cu index 16bea7f1a..aa0879ba6 100644 --- a/gsplat/cuda/csrc/bindings.cu +++ b/gsplat/cuda/csrc/bindings.cu @@ -40,6 +40,7 @@ std::tuple< torch::Tensor, // output conics torch::Tensor> // output radii compute_cov2d_bounds_tensor(const int num_pts, torch::Tensor &covs2d) { + DEVICE_GUARD(covs2d); CHECK_INPUT(covs2d); torch::Tensor conics = torch::zeros( {num_pts, covs2d.size(1)}, covs2d.options().dtype(torch::kFloat32) @@ -65,6 +66,7 @@ torch::Tensor compute_sh_forward_tensor( torch::Tensor &viewdirs, torch::Tensor &coeffs ) { + DEVICE_GUARD(viewdirs); unsigned num_bases = num_sh_bases(degree); if (coeffs.ndimension() != 3 || coeffs.size(0) != num_points || coeffs.size(1) != num_bases || coeffs.size(2) != 3) { @@ -91,6 +93,7 @@ torch::Tensor compute_sh_backward_tensor( torch::Tensor &viewdirs, torch::Tensor &v_colors ) { + DEVICE_GUARD(viewdirs); if (viewdirs.ndimension() != 2 || viewdirs.size(0) != num_points || viewdirs.size(1) != 3) { AT_ERROR("viewdirs must have dimensions (N, 3)"); @@ -141,6 +144,7 @@ project_gaussians_forward_tensor( const unsigned block_width, const float clip_thresh ) { + DEVICE_GUARD(means3d); dim3 img_size_dim3; img_size_dim3.x = img_width; img_size_dim3.y = img_height; @@ -225,6 +229,7 @@ project_gaussians_backward_tensor( torch::Tensor &v_depth, torch::Tensor &v_conic ) { + DEVICE_GUARD(means3d); dim3 img_size_dim3; img_size_dim3.x = img_width; img_size_dim3.y = img_height; @@ -284,6 +289,7 @@ std::tuple map_gaussian_to_intersects_tensor( const std::tuple tile_bounds, const unsigned block_width ) { + DEVICE_GUARD(xys); CHECK_INPUT(xys); CHECK_INPUT(depths); CHECK_INPUT(radii); @@ -321,6 +327,7 @@ torch::Tensor get_tile_bin_edges_tensor( int num_intersects, const torch::Tensor &isect_ids_sorted, const std::tuple tile_bounds ) { + DEVICE_GUARD(isect_ids_sorted); CHECK_INPUT(isect_ids_sorted); int num_tiles = std::get<0>(tile_bounds) * std::get<1>(tile_bounds); torch::Tensor tile_bins = torch::zeros( @@ -349,6 +356,7 @@ rasterize_forward_tensor( const torch::Tensor &opacities, const torch::Tensor &background ) { + DEVICE_GUARD(xys); CHECK_INPUT(gaussian_ids_sorted); CHECK_INPUT(tile_bins); CHECK_INPUT(xys); @@ -418,6 +426,7 @@ nd_rasterize_forward_tensor( const torch::Tensor &opacities, const torch::Tensor &background ) { + DEVICE_GUARD(xys); CHECK_INPUT(gaussian_ids_sorted); CHECK_INPUT(tile_bins); CHECK_INPUT(xys); @@ -504,7 +513,7 @@ std:: const torch::Tensor &v_output, // dL_dout_color const torch::Tensor &v_output_alpha // dL_dout_alpha ) { - + DEVICE_GUARD(xys); CHECK_INPUT(xys); CHECK_INPUT(colors); @@ -585,7 +594,7 @@ std:: const torch::Tensor &v_output, // dL_dout_color const torch::Tensor &v_output_alpha // dL_dout_alpha ) { - + DEVICE_GUARD(xys); CHECK_INPUT(xys); CHECK_INPUT(colors); diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index 7db57e86e..f6e40a904 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -5,6 +5,7 @@ #include #include #include +#include #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) \ @@ -12,6 +13,8 @@ #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) +#define DEVICE_GUARD(_ten) \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten)); std::tuple< torch::Tensor, // output conics