From a8ec61fc559a8721b58665465d89073885d1e556 Mon Sep 17 00:00:00 2001 From: gemaozhou Date: Wed, 25 Oct 2023 19:38:43 +0800 Subject: [PATCH] Optimize bev_pool_grad_kernel --- mmdet3d/ops/bev_pool_v2/bev_pool.py | 2 +- mmdet3d/ops/bev_pool_v2/src/bev_pool.cpp | 33 +++++++++++++++++ mmdet3d/ops/bev_pool_v2/src/bev_pool_cuda.cu | 39 ++++++++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) diff --git a/mmdet3d/ops/bev_pool_v2/bev_pool.py b/mmdet3d/ops/bev_pool_v2/bev_pool.py index 192c39e0..c2fff115 100644 --- a/mmdet3d/ops/bev_pool_v2/bev_pool.py +++ b/mmdet3d/ops/bev_pool_v2/bev_pool.py @@ -67,7 +67,7 @@ def backward(ctx, out_grad): depth_grad = depth.new_zeros(depth.shape) feat_grad = feat.new_zeros(feat.shape) out_grad = out_grad.contiguous() - bev_pool_v2_ext.bev_pool_v2_backward( + bev_pool_v2_ext.bev_pool_v2_backward_opt( out_grad, depth_grad, feat_grad, diff --git a/mmdet3d/ops/bev_pool_v2/src/bev_pool.cpp b/mmdet3d/ops/bev_pool_v2/src/bev_pool.cpp index c7c38f69..7a217089 100644 --- a/mmdet3d/ops/bev_pool_v2/src/bev_pool.cpp +++ b/mmdet3d/ops/bev_pool_v2/src/bev_pool.cpp @@ -13,6 +13,12 @@ void bev_pool_v2_grad(int c, int n_intervals, const float* out_grad, const int* ranks_bev, const int* interval_starts, const int* interval_lengths, float* depth_grad, float* feat_grad); +void bev_pool_v2_grad_opt(int c, int n_intervals, const float *out_grad, + const float *depth, const float *feat, + const int *ranks_depth, const int *ranks_feat, + const int *ranks_bev, const int *interval_starts, + const int *interval_lengths, float *depth_grad, + float *feat_grad); /* Function: pillar pooling (forward, cuda) @@ -103,9 +109,36 @@ void bev_pool_v2_backward( ); } +void bev_pool_v2_backward_opt( + const at::Tensor _out_grad, at::Tensor _depth_grad, at::Tensor _feat_grad, + const at::Tensor _depth, const at::Tensor _feat, + const at::Tensor _ranks_depth, const at::Tensor _ranks_feat, + const at::Tensor _ranks_bev, const at::Tensor _interval_lengths, + const 
at::Tensor _interval_starts) { int c = _out_grad.size(4); int n_intervals = _interval_lengths.size(0); const at::cuda::OptionalCUDAGuard device_guard(device_of(_out_grad)); const float *out_grad = _out_grad.data_ptr<float>(); float *depth_grad = _depth_grad.data_ptr<float>(); float *feat_grad = _feat_grad.data_ptr<float>(); const float *depth = _depth.data_ptr<float>(); const float *feat = _feat.data_ptr<float>(); const int *ranks_depth = _ranks_depth.data_ptr<int>(); const int *ranks_feat = _ranks_feat.data_ptr<int>(); const int *ranks_bev = _ranks_bev.data_ptr<int>(); const int *interval_lengths = _interval_lengths.data_ptr<int>(); const int *interval_starts = _interval_starts.data_ptr<int>(); + bev_pool_v2_grad_opt(c, n_intervals, out_grad, depth, feat, ranks_depth, + ranks_feat, ranks_bev, interval_starts, + interval_lengths, depth_grad, feat_grad); +} + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("bev_pool_v2_forward", &bev_pool_v2_forward, "bev_pool_v2_forward"); m.def("bev_pool_v2_backward", &bev_pool_v2_backward, "bev_pool_v2_backward"); + m.def("bev_pool_v2_backward_opt", &bev_pool_v2_backward_opt, + "bev_pool_v2_backward_opt"); } diff --git a/mmdet3d/ops/bev_pool_v2/src/bev_pool_cuda.cu b/mmdet3d/ops/bev_pool_v2/src/bev_pool_cuda.cu index 7fa3179b..eb156fd4 100644 --- a/mmdet3d/ops/bev_pool_v2/src/bev_pool_cuda.cu +++ b/mmdet3d/ops/bev_pool_v2/src/bev_pool_cuda.cu @@ -120,7 +120,35 @@ __global__ void bev_pool_grad_kernel(int c, int n_intervals, } } +__global__ void bev_pool_grad_kernel_opt( + int c, int n_intervals, const float *__restrict__ out_grad, + const float *__restrict__ depth, const float *__restrict__ feat, + const int *__restrict__ ranks_depth, const int *__restrict__ ranks_feat, + const int *__restrict__ ranks_bev, const int *__restrict__ interval_starts, + const int *__restrict__ interval_lengths, float *__restrict__ depth_grad, + float *__restrict__ feat_grad) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int index = idx / c; + int cur_c = idx % c; + if 
(index >= n_intervals) + return; + + int interval_start = interval_starts[index]; + int interval_length = interval_lengths[index]; + for (int i = 0; i < interval_length; ++i) { + const float *cur_out_grad_start = + out_grad + ranks_bev[interval_start + i] * c; + const float *cur_feat_start = feat + ranks_feat[interval_start + i] * c; + float *cur_depth_grad = depth_grad + ranks_depth[interval_start + i]; + atomicAdd(cur_depth_grad, + cur_out_grad_start[cur_c] * cur_feat_start[cur_c]); + const int cur_rank = ranks_bev[interval_start + i]; + float *cur_feat_grad = feat_grad + ranks_feat[interval_start + i] * c + cur_c; + atomicAdd(cur_feat_grad, out_grad[cur_rank * c + cur_c] * + depth[ranks_depth[interval_start + i]]); + } +} void bev_pool_v2(int c, int n_intervals, const float* depth, const float* feat, const int* ranks_depth, const int* ranks_feat, const int* ranks_bev, const int* interval_starts, const int* interval_lengths, float* out) { @@ -138,3 +166,14 @@ void bev_pool_v2_grad(int c, int n_intervals, const float* out_grad, ranks_bev, interval_starts, interval_lengths, depth_grad, feat_grad ); } + +void bev_pool_v2_grad_opt(int c, int n_intervals, const float *out_grad, + const float *depth, const float *feat, + const int *ranks_depth, const int *ranks_feat, + const int *ranks_bev, const int *interval_starts, + const int *interval_lengths, float *depth_grad, + float *feat_grad) { + bev_pool_grad_kernel_opt<<<(int)ceil(((double)n_intervals * c / 256)), 256>>>( + c, n_intervals, out_grad, depth, feat, ranks_depth, ranks_feat, ranks_bev, + interval_starts, interval_lengths, depth_grad, feat_grad); +}