Skip to content

Commit

Permalink
add device guard (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
liruilong940607 authored Feb 24, 2024
1 parent a45e203 commit 24215cb
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
13 changes: 11 additions & 2 deletions gsplat/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ std::tuple<
torch::Tensor, // output conics
torch::Tensor> // output radii
compute_cov2d_bounds_tensor(const int num_pts, torch::Tensor &covs2d) {
DEVICE_GUARD(covs2d);
CHECK_INPUT(covs2d);
torch::Tensor conics = torch::zeros(
{num_pts, covs2d.size(1)}, covs2d.options().dtype(torch::kFloat32)
Expand All @@ -65,6 +66,7 @@ torch::Tensor compute_sh_forward_tensor(
torch::Tensor &viewdirs,
torch::Tensor &coeffs
) {
DEVICE_GUARD(viewdirs);
unsigned num_bases = num_sh_bases(degree);
if (coeffs.ndimension() != 3 || coeffs.size(0) != num_points ||
coeffs.size(1) != num_bases || coeffs.size(2) != 3) {
Expand All @@ -91,6 +93,7 @@ torch::Tensor compute_sh_backward_tensor(
torch::Tensor &viewdirs,
torch::Tensor &v_colors
) {
DEVICE_GUARD(viewdirs);
if (viewdirs.ndimension() != 2 || viewdirs.size(0) != num_points ||
viewdirs.size(1) != 3) {
AT_ERROR("viewdirs must have dimensions (N, 3)");
Expand Down Expand Up @@ -141,6 +144,7 @@ project_gaussians_forward_tensor(
const unsigned block_width,
const float clip_thresh
) {
DEVICE_GUARD(means3d);
dim3 img_size_dim3;
img_size_dim3.x = img_width;
img_size_dim3.y = img_height;
Expand Down Expand Up @@ -225,6 +229,7 @@ project_gaussians_backward_tensor(
torch::Tensor &v_depth,
torch::Tensor &v_conic
) {
DEVICE_GUARD(means3d);
dim3 img_size_dim3;
img_size_dim3.x = img_width;
img_size_dim3.y = img_height;
Expand Down Expand Up @@ -284,6 +289,7 @@ std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
const std::tuple<int, int, int> tile_bounds,
const unsigned block_width
) {
DEVICE_GUARD(xys);
CHECK_INPUT(xys);
CHECK_INPUT(depths);
CHECK_INPUT(radii);
Expand Down Expand Up @@ -321,6 +327,7 @@ torch::Tensor get_tile_bin_edges_tensor(
int num_intersects, const torch::Tensor &isect_ids_sorted,
const std::tuple<int, int, int> tile_bounds
) {
DEVICE_GUARD(isect_ids_sorted);
CHECK_INPUT(isect_ids_sorted);
int num_tiles = std::get<0>(tile_bounds) * std::get<1>(tile_bounds);
torch::Tensor tile_bins = torch::zeros(
Expand Down Expand Up @@ -349,6 +356,7 @@ rasterize_forward_tensor(
const torch::Tensor &opacities,
const torch::Tensor &background
) {
DEVICE_GUARD(xys);
CHECK_INPUT(gaussian_ids_sorted);
CHECK_INPUT(tile_bins);
CHECK_INPUT(xys);
Expand Down Expand Up @@ -418,6 +426,7 @@ nd_rasterize_forward_tensor(
const torch::Tensor &opacities,
const torch::Tensor &background
) {
DEVICE_GUARD(xys);
CHECK_INPUT(gaussian_ids_sorted);
CHECK_INPUT(tile_bins);
CHECK_INPUT(xys);
Expand Down Expand Up @@ -504,7 +513,7 @@ std::
const torch::Tensor &v_output, // dL_dout_color
const torch::Tensor &v_output_alpha // dL_dout_alpha
) {

DEVICE_GUARD(xys);
CHECK_INPUT(xys);
CHECK_INPUT(colors);

Expand Down Expand Up @@ -585,7 +594,7 @@ std::
const torch::Tensor &v_output, // dL_dout_color
const torch::Tensor &v_output_alpha // dL_dout_alpha
) {

DEVICE_GUARD(xys);
CHECK_INPUT(xys);
CHECK_INPUT(colors);

Expand Down
3 changes: 3 additions & 0 deletions gsplat/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
#include <math.h>
#include <torch/extension.h>
#include <tuple>
#include <c10/cuda/CUDAGuard.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define DEVICE_GUARD(_ten) \
const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten));

std::tuple<
torch::Tensor, // output conics
Expand Down

0 comments on commit 24215cb

Please sign in to comment.