Skip to content

Commit

Permalink
Optimize bev_pool_grad_kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
gemaozhou committed Oct 25, 2023
1 parent f71858d commit 81b080f
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 1 deletion.
2 changes: 1 addition & 1 deletion mmdet3d/ops/bev_pool_v2/bev_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def backward(ctx, out_grad):
depth_grad = depth.new_zeros(depth.shape)
feat_grad = feat.new_zeros(feat.shape)
out_grad = out_grad.contiguous()
bev_pool_v2_ext.bev_pool_v2_backward(
bev_pool_v2_ext.bev_pool_v2_backward_opt(
out_grad,
depth_grad,
feat_grad,
Expand Down
33 changes: 33 additions & 0 deletions mmdet3d/ops/bev_pool_v2/src/bev_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ void bev_pool_v2_grad(int c, int n_intervals, const float* out_grad,
const int* ranks_bev, const int* interval_starts, const int* interval_lengths,
float* depth_grad, float* feat_grad);

void bev_pool_v2_grad_opt(int c, int n_intervals, const float *out_grad,
const float *depth, const float *feat,
const int *ranks_depth, const int *ranks_feat,
const int *ranks_bev, const int *interval_starts,
const int *interval_lengths, float *depth_grad,
float *feat_grad);

/*
Function: pillar pooling (forward, cuda)
Expand Down Expand Up @@ -103,9 +109,36 @@ void bev_pool_v2_backward(
);
}

void bev_pool_v2_backward_opt(
const at::Tensor _out_grad, at::Tensor _depth_grad, at::Tensor _feat_grad,
const at::Tensor _depth, const at::Tensor _feat,
const at::Tensor _ranks_depth, const at::Tensor _ranks_feat,
const at::Tensor _ranks_bev, const at::Tensor _interval_lengths,
const at::Tensor _interval_starts) {
int c = _out_grad.size(4);
int n_intervals = _interval_lengths.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_out_grad));
const float *out_grad = _out_grad.data_ptr<float>();
float *depth_grad = _depth_grad.data_ptr<float>();
float *feat_grad = _feat_grad.data_ptr<float>();
const float *depth = _depth.data_ptr<float>();
const float *feat = _feat.data_ptr<float>();
const int *ranks_depth = _ranks_depth.data_ptr<int>();
const int *ranks_feat = _ranks_feat.data_ptr<int>();
const int *ranks_bev = _ranks_bev.data_ptr<int>();
const int *interval_lengths = _interval_lengths.data_ptr<int>();
const int *interval_starts = _interval_starts.data_ptr<int>();

bev_pool_v2_grad_opt(c, n_intervals, out_grad, depth, feat, ranks_depth,
ranks_feat, ranks_bev, interval_starts,
interval_lengths, depth_grad, feat_grad);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("bev_pool_v2_forward", &bev_pool_v2_forward,
"bev_pool_v2_forward");
m.def("bev_pool_v2_backward", &bev_pool_v2_backward,
"bev_pool_v2_backward");
m.def("bev_pool_v2_backward_opt", &bev_pool_v2_backward_opt,
"bev_pool_v2_backward_opt");
}
25 changes: 25 additions & 0 deletions mmdet3d/ops/bev_pool_v2/src/bev_pool_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,21 @@ __global__ void bev_pool_grad_kernel(int c, int n_intervals,
}
}

__global__ void bev_pool_grad_kernel_opt(
int c, int n_intervals, const float *__restrict__ out_grad,
const float *__restrict__ depth, const float *__restrict__ feat,
const int *__restrict__ ranks_depth, const int *__restrict__ ranks_feat,
const int *__restrict__ ranks_bev, const int *__restrict__ interval_starts,
const int *__restrict__ interval_lengths, float *__restrict__ depth_grad,
float *__restrict__ feat_grad) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int index = idx / c;
int cur_c = idx % c;
if (index >= n_intervals)
return;

int interval_start = interval_starts[index];
int interval_length = interval_lengths[index];

void bev_pool_v2(int c, int n_intervals, const float* depth, const float* feat, const int* ranks_depth,
const int* ranks_feat, const int* ranks_bev, const int* interval_starts, const int* interval_lengths, float* out) {
Expand All @@ -138,3 +152,14 @@ void bev_pool_v2_grad(int c, int n_intervals, const float* out_grad,
ranks_bev, interval_starts, interval_lengths, depth_grad, feat_grad
);
}

void bev_pool_v2_grad_opt(int c, int n_intervals, const float *out_grad,
const float *depth, const float *feat,
const int *ranks_depth, const int *ranks_feat,
const int *ranks_bev, const int *interval_starts,
const int *interval_lengths, float *depth_grad,
float *feat_grad) {
bev_pool_grad_kernel_opt<<<(int)ceil(((double)n_intervals * c / 256)), 256>>>(
c, n_intervals, out_grad, depth, feat, ranks_depth, ranks_feat, ranks_bev,
interval_starts, interval_lengths, depth_grad, feat_grad);
}

0 comments on commit 81b080f

Please sign in to comment.