Skip to content

Commit

Permalink
Use int64_t instead of long
Browse files Browse the repository at this point in the history
  • Loading branch information
skywolf829 committed Oct 2, 2024
1 parent 8a66878 commit 9672603
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions nerfacc/cuda/csrc/scan_cub.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ torch::Tensor inclusive_sum_cub(
#if CUB_SUPPORTS_SCAN_BY_KEY()
if (backward) {
inclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
n_edges);
} else {
inclusive_sum_by_key(
indices.data_ptr<long>(),
indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(),
outputs.data_ptr<float>(),
n_edges);
Expand Down Expand Up @@ -129,13 +129,13 @@ torch::Tensor exclusive_sum_cub(
#if CUB_SUPPORTS_SCAN_BY_KEY()
if (backward) {
exclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
n_edges);
} else {
exclusive_sum_by_key(
indices.data_ptr<long>(),
indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(),
outputs.data_ptr<float>(),
n_edges);
Expand Down Expand Up @@ -169,7 +169,7 @@ torch::Tensor inclusive_prod_cub_forward(

#if CUB_SUPPORTS_SCAN_BY_KEY()
inclusive_prod_by_key(
indices.data_ptr<long>(),
indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(),
outputs.data_ptr<float>(),
n_edges);
Expand Down Expand Up @@ -203,7 +203,7 @@ torch::Tensor inclusive_prod_cub_backward(
}
#if CUB_SUPPORTS_SCAN_BY_KEY()
inclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
n_edges);
Expand Down Expand Up @@ -237,7 +237,7 @@ torch::Tensor exclusive_prod_cub_forward(
}
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_prod_by_key(
indices.data_ptr<long>(),
indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(),
outputs.data_ptr<float>(),
n_edges);
Expand Down Expand Up @@ -272,7 +272,7 @@ torch::Tensor exclusive_prod_cub_backward(

#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
n_edges);
Expand Down

0 comments on commit 9672603

Please sign in to comment.