From 65b936e521bba0d006e52ce5cde9dde956acbacd Mon Sep 17 00:00:00 2001 From: lbin Date: Wed, 29 Jul 2020 20:07:05 +0800 Subject: [PATCH 1/3] add pytorch 1.6 support(cpu not support yet) --- dcn_v2.py | 432 ++++++++++++++++---------- setup.py | 21 +- src/cpu/dcn_v2_cpu.cpp | 273 ++++++++++++---- src/cpu/dcn_v2_im2col_cpu.cpp | 395 +++++++++++++++++++++++ src/cpu/dcn_v2_im2col_cpu.h | 99 ++++++ src/cpu/dcn_v2_psroi_pooling_cpu.cpp | 426 +++++++++++++++++++++++++ src/cuda/dcn_v2_cuda.cu | 130 +++++--- src/cuda/dcn_v2_im2col_cuda.cu | 16 +- src/cuda/dcn_v2_psroi_pooling_cuda.cu | 36 +-- src/cuda/vision.h | 2 +- src/dcn_v2.h | 53 +++- 11 files changed, 1566 insertions(+), 317 deletions(-) create mode 100644 src/cpu/dcn_v2_im2col_cpu.cpp create mode 100644 src/cpu/dcn_v2_im2col_cpu.h create mode 100644 src/cpu/dcn_v2_psroi_pooling_cpu.cpp diff --git a/dcn_v2.py b/dcn_v2.py index 742786b..ce654b5 100644 --- a/dcn_v2.py +++ b/dcn_v2.py @@ -1,70 +1,109 @@ #!/usr/bin/env python -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division +from __future__ import absolute_import, division, print_function import math + import torch from torch import nn from torch.autograd import Function -from torch.nn.modules.utils import _pair from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair import _ext as _backend -try: - from apex import amp -except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") - class _DCNv2(Function): @staticmethod - @amp.float_function - def forward(ctx, input, offset, mask, weight, bias, - stride, padding, dilation, deformable_groups): + def forward( + ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups + ): ctx.stride = _pair(stride) ctx.padding = _pair(padding) ctx.dilation = _pair(dilation) ctx.kernel_size = _pair(weight.shape[2:4]) ctx.deformable_groups = deformable_groups - output = _backend.dcn_v2_forward(input, weight, bias, - offset, mask, - ctx.kernel_size[0], ctx.kernel_size[1], - ctx.stride[0], ctx.stride[1], - ctx.padding[0], ctx.padding[1], - ctx.dilation[0], ctx.dilation[1], - ctx.deformable_groups) + output = _backend.dcn_v2_forward( + input, + weight, + bias, + offset, + mask, + ctx.kernel_size[0], + ctx.kernel_size[1], + ctx.stride[0], + ctx.stride[1], + ctx.padding[0], + ctx.padding[1], + ctx.dilation[0], + ctx.dilation[1], + ctx.deformable_groups, + ) ctx.save_for_backward(input, offset, mask, weight, bias) return output @staticmethod @once_differentiable - @amp.float_function def backward(ctx, grad_output): input, offset, mask, weight, bias = ctx.saved_tensors - grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \ - _backend.dcn_v2_backward(input, weight, - bias, - offset, mask, - grad_output, - ctx.kernel_size[0], ctx.kernel_size[1], - ctx.stride[0], ctx.stride[1], - ctx.padding[0], ctx.padding[1], - ctx.dilation[0], ctx.dilation[1], - ctx.deformable_groups) + grad_input, grad_offset, grad_mask, grad_weight, grad_bias = _backend.dcn_v2_backward( + input, + weight, + bias, + offset, + mask, + grad_output, + ctx.kernel_size[0], + ctx.kernel_size[1], + ctx.stride[0], + ctx.stride[1], + ctx.padding[0], + ctx.padding[1], + ctx.dilation[0], + ctx.dilation[1], + ctx.deformable_groups, + ) + + return grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None - return grad_input, grad_offset, grad_mask, grad_weight, grad_bias,\ - None, None, None, 
None, + @staticmethod + def symbolic( + g, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups + ): + from torch.nn.modules.utils import _pair + + stride = _pair(stride) + padding = _pair(padding) + dilation = _pair(dilation) + # as of trt 7, the dcn operation will be translated again by modifying the onnx file + # so the exporting code is kept to resemble the forward() + return g.op( + "DCNv2_2", + input, + offset, + mask, + weight, + bias, + stride_i=stride, + padding_i=padding, + dilation_i=dilation, + deformable_groups_i=deformable_groups, + ) dcn_v2_conv = _DCNv2.apply class DCNv2(nn.Module): - - def __init__(self, in_channels, out_channels, - kernel_size, stride, padding, dilation=1, deformable_groups=1): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation=1, + deformable_groups=1, + ): super(DCNv2, self).__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -74,8 +113,7 @@ def __init__(self, in_channels, out_channels, self.dilation = _pair(dilation) self.deformable_groups = deformable_groups - self.weight = nn.Parameter(torch.Tensor( - out_channels, in_channels, *self.kernel_size)) + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) self.bias = nn.Parameter(torch.Tensor(out_channels)) self.reset_parameters() @@ -83,39 +121,53 @@ def reset_parameters(self): n = self.in_channels for k in self.kernel_size: n *= k - stdv = 1. / math.sqrt(n) + stdv = 1.0 / math.sqrt(n) self.weight.data.uniform_(-stdv, stdv) self.bias.data.zero_() def forward(self, input, offset, mask): - assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ - offset.shape[1] - assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ - mask.shape[1] - return dcn_v2_conv(input, offset, mask, - self.weight, - self.bias, - self.stride, - self.padding, - self.dilation, - self.deformable_groups) + assert ( + 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] + == offset.shape[1] + ) + assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == mask.shape[1] + return dcn_v2_conv( + input, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.deformable_groups, + ) class DCN(DCNv2): - - def __init__(self, in_channels, out_channels, - kernel_size, stride, padding, - dilation=1, deformable_groups=1): - super(DCN, self).__init__(in_channels, out_channels, - kernel_size, stride, padding, dilation, deformable_groups) + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation=1, + deformable_groups=1, + ): + super(DCN, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups + ) channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1] - self.conv_offset_mask = nn.Conv2d(self.in_channels, - channels_, - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding, - bias=True) + self.conv_offset_mask = nn.Conv2d( + self.in_channels, + channels_, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + bias=True, + ) self.init_offset() def init_offset(self): @@ -127,26 +179,35 @@ def forward(self, input): o1, o2, mask = torch.chunk(out, 3, dim=1) offset = torch.cat((o1, o2), dim=1) mask = torch.sigmoid(mask) - return dcn_v2_conv(input, offset, mask, - self.weight, self.bias, - self.stride, - self.padding, - 
self.dilation, - self.deformable_groups) - + return dcn_v2_conv( + input, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.deformable_groups, + ) class _DCNv2Pooling(Function): @staticmethod - def forward(ctx, input, rois, offset, - spatial_scale, - pooled_size, - output_dim, - no_trans, - group_size=1, - part_size=None, - sample_per_part=4, - trans_std=.0): + def forward( + ctx, + input, + rois, + offset, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=0.0, + ): ctx.spatial_scale = spatial_scale ctx.no_trans = int(no_trans) ctx.output_dim = output_dim @@ -156,12 +217,19 @@ def forward(ctx, input, rois, offset, ctx.sample_per_part = sample_per_part ctx.trans_std = trans_std - output, output_count = \ - _backend.dcn_v2_psroi_pooling_forward(input, rois, offset, - ctx.no_trans, ctx.spatial_scale, - ctx.output_dim, ctx.group_size, - ctx.pooled_size, ctx.part_size, - ctx.sample_per_part, ctx.trans_std) + output, output_count = _backend.dcn_v2_psroi_pooling_forward( + input, + rois, + offset, + ctx.no_trans, + ctx.spatial_scale, + ctx.output_dim, + ctx.group_size, + ctx.pooled_size, + ctx.part_size, + ctx.sample_per_part, + ctx.trans_std, + ) ctx.save_for_backward(input, rois, offset, output_count) return output @@ -169,39 +237,40 @@ def forward(ctx, input, rois, offset, @once_differentiable def backward(ctx, grad_output): input, rois, offset, output_count = ctx.saved_tensors - grad_input, grad_offset = \ - _backend.dcn_v2_psroi_pooling_backward(grad_output, - input, - rois, - offset, - output_count, - ctx.no_trans, - ctx.spatial_scale, - ctx.output_dim, - ctx.group_size, - ctx.pooled_size, - ctx.part_size, - ctx.sample_per_part, - ctx.trans_std) - - return grad_input, None, grad_offset, \ - None, None, None, None, None, None, None, None + grad_input, grad_offset = _backend.dcn_v2_psroi_pooling_backward( + grad_output, + input, + rois, + offset, + output_count, + ctx.no_trans, + ctx.spatial_scale, + ctx.output_dim, + ctx.group_size, + ctx.pooled_size, + ctx.part_size, + ctx.sample_per_part, + ctx.trans_std, + ) + + return grad_input, None, grad_offset, None, None, None, None, None, None, None, None dcn_v2_pooling = _DCNv2Pooling.apply class DCNv2Pooling(nn.Module): - - def __init__(self, - spatial_scale, - pooled_size, - output_dim, - no_trans, - group_size=1, - part_size=None, - sample_per_part=4, - trans_std=.0): + def __init__( + self, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=0.0, + ): super(DCNv2Pooling, self).__init__() self.spatial_scale = spatial_scale self.pooled_size = pooled_size @@ -216,49 +285,56 @@ def forward(self, input, rois, offset): assert input.shape[1] == self.output_dim if self.no_trans: offset = input.new() - return dcn_v2_pooling(input, rois, offset, - self.spatial_scale, - self.pooled_size, - self.output_dim, - self.no_trans, - self.group_size, - self.part_size, - self.sample_per_part, - self.trans_std) + return dcn_v2_pooling( + input, + rois, + offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + self.no_trans, + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std, + ) class DCNPooling(DCNv2Pooling): - - def __init__(self, - spatial_scale, - pooled_size, - output_dim, - no_trans, - group_size=1, - part_size=None, - sample_per_part=4, - trans_std=.0, - deform_fc_dim=1024): - super(DCNPooling, 
self).__init__(spatial_scale, - pooled_size, - output_dim, - no_trans, - group_size, - part_size, - sample_per_part, - trans_std) + def __init__( + self, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=0.0, + deform_fc_dim=1024, + ): + super(DCNPooling, self).__init__( + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size, + part_size, + sample_per_part, + trans_std, + ) self.deform_fc_dim = deform_fc_dim if not no_trans: self.offset_mask_fc = nn.Sequential( - nn.Linear(self.pooled_size * self.pooled_size * - self.output_dim, self.deform_fc_dim), + nn.Linear( + self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim + ), nn.ReLU(inplace=True), nn.Linear(self.deform_fc_dim, self.deform_fc_dim), nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_dim, self.pooled_size * - self.pooled_size * 3) + nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 3), ) self.offset_mask_fc[4].weight.data.zero_() self.offset_mask_fc[4].bias.data.zero_() @@ -270,41 +346,55 @@ def forward(self, input, rois): # do roi_align first n = rois.shape[0] - roi = dcn_v2_pooling(input, rois, offset, - self.spatial_scale, - self.pooled_size, - self.output_dim, - True, # no trans - self.group_size, - self.part_size, - self.sample_per_part, - self.trans_std) + roi = dcn_v2_pooling( + input, + rois, + offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + True, # no trans + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std, + ) # build mask and offset offset_mask = self.offset_mask_fc(roi.view(n, -1)) - offset_mask = offset_mask.view( - n, 3, self.pooled_size, self.pooled_size) + offset_mask = offset_mask.view(n, 3, self.pooled_size, self.pooled_size) o1, o2, mask = torch.chunk(offset_mask, 3, dim=1) offset = torch.cat((o1, o2), dim=1) mask = torch.sigmoid(mask) # do pooling with offset and mask - return dcn_v2_pooling(input, rois, offset, - self.spatial_scale, - self.pooled_size, - self.output_dim, - self.no_trans, - self.group_size, - self.part_size, - self.sample_per_part, - self.trans_std) * mask + return ( + dcn_v2_pooling( + input, + rois, + offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + self.no_trans, + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std, + ) + * mask + ) # only roi_align - return dcn_v2_pooling(input, rois, offset, - self.spatial_scale, - self.pooled_size, - self.output_dim, - self.no_trans, - self.group_size, - self.part_size, - self.sample_per_part, - self.trans_std) + return dcn_v2_pooling( + input, + rois, + offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + self.no_trans, + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std, + ) diff --git a/setup.py b/setup.py index 1082494..7ee6edb 100644 --- a/setup.py +++ b/setup.py @@ -1,19 +1,15 @@ #!/usr/bin/env python -import os import glob +import os import torch - -from torch.utils.cpp_extension import CUDA_HOME -from torch.utils.cpp_extension import CppExtension -from torch.utils.cpp_extension import CUDAExtension - -from setuptools import find_packages -from setuptools import setup +from setuptools import find_packages, setup +from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension requirements = ["torch", "torchvision"] + def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "src") @@ -22,6 +18,7 @@ def 
get_extensions(): source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) + os.environ["CC"] = "g++" sources = main_file + source_cpu extension = CppExtension extra_compile_args = {"cxx": []} @@ -38,7 +35,8 @@ def get_extensions(): "-D__CUDA_NO_HALF2_OPERATORS__", ] else: - raise NotImplementedError('Cuda is not availabel') + # raise NotImplementedError('Cuda is not available') + pass sources = [os.path.join(extensions_dir, s) for s in sources] include_dirs = [extensions_dir] @@ -53,14 +51,15 @@ def get_extensions(): ] return ext_modules + setup( name="DCNv2", version="0.1", author="charlesshang", url="https://github.com/charlesshang/DCNv2", description="deformable convolutional networks", - packages=find_packages(exclude=("configs", "tests",)), + packages=find_packages(exclude=("configs", "tests")), # install_requires=requirements, ext_modules=get_extensions(), cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, -) \ No newline at end of file +) diff --git a/src/cpu/dcn_v2_cpu.cpp b/src/cpu/dcn_v2_cpu.cpp index a68ccef..76c65f0 100644 --- a/src/cpu/dcn_v2_cpu.cpp +++ b/src/cpu/dcn_v2_cpu.cpp @@ -1,74 +1,233 @@ #include +#include "cpu/dcn_v2_im2col_cpu.h" #include -#include +//#include +#include +//#include +//#include + +//extern THCState *state; + +// author: Charles Shang +// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu +// modified from the CUDA version for CPU use by Daniel K. Suhendro at::Tensor dcn_v2_cpu_forward(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - const at::Tensor &offset, - const at::Tensor &mask, - const int kernel_h, - const int kernel_w, - const int stride_h, - const int stride_w, - const int pad_h, - const int pad_w, - const int dilation_h, - const int dilation_w, - const int deformable_group) -{ - AT_ERROR("Not implement on cpu"); -} - -std::vector -dcn_v2_cpu_backward(const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, const at::Tensor &offset, const at::Tensor &mask, - const at::Tensor &grad_output, - int kernel_h, int kernel_w, - int stride_h, int stride_w, - int pad_h, int pad_w, - int dilation_h, int dilation_w, - int deformable_group) + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int deformable_group) { - AT_ERROR("Not implement on cpu"); -} + // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask)); + /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor"); + AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor"); + AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor"); + AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/ -std::tuple -dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input, - const at::Tensor &bbox, - const at::Tensor &trans, - const int no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std) -{ - AT_ERROR("Not implement on cpu"); + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int 
channels_kernel = weight.size(1);
+    const int kernel_h_ = weight.size(2);
+    const int kernel_w_ = weight.size(3);
+
+    // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h);
+    // printf("Channels: %d %d\n", channels, channels_kernel);
+    // printf("Channels: %d %d\n", channels_out, channels_kernel);
+
+    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
+               "Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h, kernel_w, kernel_h_, kernel_w_);
+
+    AT_ASSERTM(channels == channels_kernel,
+               "Input shape and kernel channels won't match: (%d vs %d).", channels, channels_kernel);
+
+    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+    auto ones = at::ones({height_out, width_out}, input.options());
+    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
+    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
+
+    using scalar_t = float;
+    for (int b = 0; b < batch; b++)
+    {
+        auto input_n = input.select(0, b);
+        auto offset_n = offset.select(0, b);
+        auto mask_n = mask.select(0, b);
+        auto output_n = output.select(0, b);
+
+        // Do Bias first:
+        // M,N,K are dims of matrix A and B
+        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+        // (N x 1) (1 x M)
+        long m_ = channels_out;
+        long n_ = height_out * width_out;
+        long k_ = 1;
+        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
+                         ones.contiguous().data<scalar_t>(), k_,
+                         bias.contiguous().data<scalar_t>(), k_, 0.0f,
+                         output_n.data<scalar_t>(), n_);
+
+        modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),
+                                        offset_n.data<scalar_t>(),
+                                        mask_n.data<scalar_t>(),
+                                        1, channels, height, width,
+                                        height_out, width_out, kernel_h, kernel_w,
+                                        pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                        deformable_group,
+                                        columns.data<scalar_t>());
+
+        //(k * m)  x  (m * n)
+        // Y = WC
+        long m = channels_out;
+        long n = height_out * width_out;
+        long k = channels * kernel_h * kernel_w;
+        THFloatBlas_gemm('n', 'n', n, m, k, 1.0f,
+                         columns.data<scalar_t>(), n,
+                         weight.data<scalar_t>(), k, 1.0f,
+                         output_n.data<scalar_t>(), n);
+    }
+    return output;
+}
+
+std::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input,
+                                            const at::Tensor &weight,
+                                            const at::Tensor &bias,
+                                            const at::Tensor &offset,
+                                            const at::Tensor &mask,
+                                            const at::Tensor &grad_output,
+                                            int kernel_h, int kernel_w,
+                                            int stride_h, int stride_w,
+                                            int pad_h, int pad_w,
+                                            int dilation_h, int dilation_w,
+                                            int deformable_group)
+{
+
+    THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
+    THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
+
+    /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
+    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
+    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
+    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/
+
+    const int batch = input.size(0);
+    const int channels =
input.size(1);
+    const int height = input.size(2);
+    const int width = input.size(3);
+
+    const int channels_out = weight.size(0);
+    const int channels_kernel = weight.size(1);
+    const int kernel_h_ = weight.size(2);
+    const int kernel_w_ = weight.size(3);
+
+    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
+               "Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h, kernel_w, kernel_h_, kernel_w_);
+
+    AT_ASSERTM(channels == channels_kernel,
+               "Input shape and kernel channels won't match: (%d vs %d).", channels, channels_kernel);
+
+    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+    auto ones = at::ones({height_out, width_out}, input.options());
+    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
+    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
+
+    auto grad_input = at::zeros_like(input);
+    auto grad_weight = at::zeros_like(weight);
+    auto grad_bias = at::zeros_like(bias);
+    auto grad_offset = at::zeros_like(offset);
+    auto grad_mask = at::zeros_like(mask);
+
+    using scalar_t = float;
+
+    for (int b = 0; b < batch; b++)
+    {
+        auto input_n = input.select(0, b);
+        auto offset_n = offset.select(0, b);
+        auto mask_n = mask.select(0, b);
+        auto grad_output_n = grad_output.select(0, b);
+        auto grad_input_n = grad_input.select(0, b);
+        auto grad_offset_n = grad_offset.select(0, b);
+        auto grad_mask_n = grad_mask.select(0, b);
+
+        long m = channels * kernel_h * kernel_w;
+        long n = height_out * width_out;
+        long k = channels_out;
+
+        THFloatBlas_gemm('n', 't', n, m, k, 1.0f,
+                         grad_output_n.data<scalar_t>(), n,
+                         weight.data<scalar_t>(), m, 0.0f,
+                         columns.data<scalar_t>(), n);
+
+        // gradient w.r.t. input coordinate data
+        modulated_deformable_col2im_coord_cpu(columns.data<scalar_t>(),
+                                              input_n.data<scalar_t>(),
+                                              offset_n.data<scalar_t>(),
+                                              mask_n.data<scalar_t>(),
+                                              1, channels, height, width,
+                                              height_out, width_out, kernel_h, kernel_w,
+                                              pad_h, pad_w, stride_h, stride_w,
+                                              dilation_h, dilation_w, deformable_group,
+                                              grad_offset_n.data<scalar_t>(),
+                                              grad_mask_n.data<scalar_t>());
+        // gradient w.r.t. input data
+        modulated_deformable_col2im_cpu(columns.data<scalar_t>(),
+                                        offset_n.data<scalar_t>(),
+                                        mask_n.data<scalar_t>(),
+                                        1, channels, height, width,
+                                        height_out, width_out, kernel_h, kernel_w,
+                                        pad_h, pad_w, stride_h, stride_w,
+                                        dilation_h, dilation_w, deformable_group,
+                                        grad_input_n.data<scalar_t>());
+
+        // gradient w.r.t. weight, dWeight should accumulate across the batch and group
+        modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),
+                                        offset_n.data<scalar_t>(),
+                                        mask_n.data<scalar_t>(),
+                                        1, channels, height, width,
+                                        height_out, width_out, kernel_h, kernel_w,
+                                        pad_h, pad_w, stride_h, stride_w,
+                                        dilation_h, dilation_w, deformable_group,
+                                        columns.data<scalar_t>());
+
+        long m_ = channels_out;
+        long n_ = channels * kernel_h * kernel_w;
+        long k_ = height_out * width_out;
+
+        THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
+                         columns.data<scalar_t>(), k_,
+                         grad_output_n.data<scalar_t>(), k_, 1.0f,
+                         grad_weight.data<scalar_t>(), n_);
+
+        // gradient w.r.t.
bias
+        // long m_ = channels_out;
+        // long k__ = height_out * width_out;
+        // THFloatBlas_gemv('t', k_, m_, 1.0f,
+        //                  grad_output_n.data<scalar_t>(), k_,
+        //                  ones.data<scalar_t>(), 1, 1.0f,
+        //                  grad_bias.data<scalar_t>(), 1);
+    }
+
+    return {
+        grad_input, grad_offset, grad_mask, grad_weight, grad_bias
+    };
+}
\ No newline at end of file
diff --git a/src/cpu/dcn_v2_im2col_cpu.cpp b/src/cpu/dcn_v2_im2col_cpu.cpp
new file mode 100644
index 0000000..1704a60
--- /dev/null
+++ b/src/cpu/dcn_v2_im2col_cpu.cpp
@@ -0,0 +1,395 @@
+#include "dcn_v2_im2col_cpu.h"
+#include <cstdio>
+#include <cmath>
+#include <algorithm>
+
+#include <ATen/ATen.h>
+//#include <ATen/cuda/CUDAContext.h>
+
+#include <TH/TH.h>
+//#include <THC/THCAtomics.cuh>
+//#include <THC/THCDeviceUtils.cuh>
+
+// modified from the CUDA version for CPU use by Daniel K. Suhendro
+
+/*#define CUDA_KERNEL_LOOP(i, n)                        \
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+         i < (n);                                       \
+         i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N)
+{
+    return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}*/
+
+
+float dmcn_im2col_bilinear_cpu(const float *bottom_data, const int data_width,
+                               const int height, const int width, float h, float w)
+{
+    int h_low = floor(h);
+    int w_low = floor(w);
+    int h_high = h_low + 1;
+    int w_high = w_low + 1;
+
+    float lh = h - h_low;
+    float lw = w - w_low;
+    float hh = 1 - lh, hw = 1 - lw;
+
+    float v1 = 0;
+    if (h_low >= 0 && w_low >= 0)
+        v1 = bottom_data[h_low * data_width + w_low];
+    float v2 = 0;
+    if (h_low >= 0 && w_high <= width - 1)
+        v2 = bottom_data[h_low * data_width + w_high];
+    float v3 = 0;
+    if (h_high <= height - 1 && w_low >= 0)
+        v3 = bottom_data[h_high * data_width + w_low];
+    float v4 = 0;
+    if (h_high <= height - 1 && w_high <= width - 1)
+        v4 = bottom_data[h_high * data_width + w_high];
+
+    float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+    float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+    return val;
+}
+
+float dmcn_get_gradient_weight_cpu(float argmax_h, float argmax_w,
+                                   const int h, const int w, const int height, const int width)
+{
+    if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+    {
+        //empty
+        return 0;
+    }
+
+    int argmax_h_low = floor(argmax_h);
+    int argmax_w_low = floor(argmax_w);
+    int argmax_h_high = argmax_h_low + 1;
+    int argmax_w_high = argmax_w_low + 1;
+
+    float weight = 0;
+    if (h == argmax_h_low && w == argmax_w_low)
+        weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+    if (h == argmax_h_low && w == argmax_w_high)
+        weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+    if (h == argmax_h_high && w == argmax_w_low)
+        weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+    if (h == argmax_h_high && w == argmax_w_high)
+        weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+    return weight;
+}
+
+float dmcn_get_coordinate_weight_cpu(float argmax_h, float argmax_w,
+                                     const int height, const int width, const float *im_data,
+                                     const int data_width, const int bp_dir)
+{
+    if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+    {
+        //empty
+        return 0;
+    }
+
+    int argmax_h_low = floor(argmax_h);
+    int argmax_w_low = floor(argmax_w);
+    int argmax_h_high = argmax_h_low + 1;
+    int argmax_w_high = argmax_w_low + 1;
+
+    float weight = 0;
+
+    if (bp_dir == 0)
+    {
+        if (argmax_h_low >= 0 && argmax_w_low >= 0)
+            weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+        if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+            weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+        if (argmax_h_high <= height - 1 && argmax_w_low >=
0)
+            weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+        if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+            weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+    }
+    else if (bp_dir == 1)
+    {
+        if (argmax_h_low >= 0 && argmax_w_low >= 0)
+            weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+        if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+            weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+        if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+            weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+        if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+            weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+    }
+
+    return weight;
+}
+
+void modulated_deformable_im2col_cpu_kernel(const int n, const float *data_im, const float *data_offset, const float *data_mask,
+                                            const int height, const int width, const int kernel_h, const int kernel_w,
+                                            const int pad_h, const int pad_w,
+                                            const int stride_h, const int stride_w,
+                                            const int dilation_h, const int dilation_w,
+                                            const int channel_per_deformable_group,
+                                            const int batch_size, const int num_channels, const int deformable_group,
+                                            const int height_col, const int width_col,
+                                            float *data_col)
+{
+    // launch channels * batch_size * height_col * width_col cores
+    for (int index = 0; index < n; index++)
+    {
+        // here columns is of shape (N, c*kw*kh, oh*ow), need to adapt axis
+        const int w_col = index % width_col;
+        const int h_col = (index / width_col) % height_col;
+        const int b_col = (index / width_col / height_col / num_channels) % batch_size;
+        const int c_im = (index / width_col / height_col) % num_channels;
+        const int c_col = c_im * kernel_h * kernel_w;
+
+        // compute deformable group index
+        const int deformable_group_index = c_im / channel_per_deformable_group;
+
+        const int h_in = h_col * stride_h - pad_h;
+        const int w_in = w_col * stride_w - pad_w;
+
+        float *data_col_ptr = data_col + ((b_col * num_channels * kernel_w * kernel_h + c_col) * height_col + h_col) * width_col + w_col;
+        const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+        const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+        const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+        for (int i = 0; i < kernel_h; ++i)
+        {
+            for (int j = 0; j < kernel_w; ++j)
+            {
+                const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+                const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+                const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+                const float offset_h = data_offset_ptr[data_offset_h_ptr];
+                const float offset_w = data_offset_ptr[data_offset_w_ptr];
+                const float mask = data_mask_ptr[data_mask_hw_ptr];
+                float val = static_cast<float>(0);
+                const float h_im = h_in + i * dilation_h + offset_h;
+                const float w_im = w_in + j * dilation_w + offset_w;
+                //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+                if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+                {
+                    //const float map_h = i * dilation_h + offset_h;
+                    //const float map_w = j * dilation_w + offset_w;
+                    //const int cur_height = height - h_in;
+                    //const int cur_width = width - w_in;
+                    //val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+                    val = dmcn_im2col_bilinear_cpu(data_im_ptr, width, height, width, h_im, w_im);
+                }
+                *data_col_ptr = val * mask;
+                // data_col_ptr += batch_size * height_col * width_col;
+                data_col_ptr += height_col * width_col;
+            }
+        }
+    }
+}
+
+void modulated_deformable_col2im_cpu_kernel(const int n, const float *data_col, const float *data_offset, const float *data_mask,
+                                            const int channels, const int height, const int width,
+                                            const int kernel_h, const int kernel_w,
+                                            const int pad_h, const int pad_w,
+                                            const int stride_h, const int stride_w,
+                                            const int dilation_h, const int dilation_w,
+                                            const int channel_per_deformable_group,
+                                            const int batch_size, const int deformable_group,
+                                            const int height_col, const int width_col,
+                                            float *grad_im)
+{
+    for (int index = 0; index < n; index++)
+    {
+        const int j = (index / width_col / height_col / batch_size) % kernel_w;
+        const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+        const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+        // compute the start and end of the output
+
+        const int deformable_group_index = c / channel_per_deformable_group;
+
+        int w_out = index % width_col;
+        int h_out = (index / width_col) % height_col;
+        int b = (index / width_col / height_col) % batch_size;
+        int w_in = w_out * stride_w - pad_w;
+        int h_in = h_out * stride_h - pad_h;
+
+        const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w *
height_col * width_col; + const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + const float cur_inv_h_data = h_in + i * dilation_h + offset_h; + const float cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const float cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + float weight = dmcn_get_gradient_weight_cpu(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + //atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + *(grad_im + cur_bottom_grad_pos) += weight * cur_top_grad; + + } + } + } + } +} + +void modulated_deformable_col2im_coord_cpu_kernel(const int n, const float *data_col, const float *data_im, + const float *data_offset, const float *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + float *grad_offset, float *grad_mask) +{ + for(int index = 0; index < n; index++) + { + float val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / 
height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + float inv_h = h_in + i * dilation_h + offset_h; + float inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cpu(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const float weight = dmcn_get_coordinate_weight_cpu( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cpu(const float* data_im, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float* data_col) { + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + modulated_deformable_im2col_cpu_kernel( + num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col); + + /*cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + }*/ + +} + +void modulated_deformable_col2im_cpu(const float* data_col, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float* grad_im){ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + modulated_deformable_col2im_cpu_kernel( + num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, + kernel_h, 
kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, channel_per_deformable_group,
+        batch_size, deformable_group, height_col, width_col, grad_im);
+    /*cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess)
+    {
+        printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+    }*/
+
+}
+
+void modulated_deformable_col2im_coord_cpu(const float* data_col, const float* data_im, const float* data_offset, const float* data_mask,
+                                           const int batch_size, const int channels, const int height_im, const int width_im,
+                                           const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+                                           const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+                                           const int dilation_h, const int dilation_w,
+                                           const int deformable_group,
+                                           float* grad_offset, float* grad_mask) {
+    const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+    const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+    modulated_deformable_col2im_coord_cpu_kernel(
+        num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,
+        kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+        dilation_h, dilation_w, channel_per_deformable_group,
+        batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+        grad_offset, grad_mask);
+    /*cudaError_t err = cudaGetLastError();
+    if (err != cudaSuccess)
+    {
+        printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+    }*/
+}
\ No newline at end of file
diff --git a/src/cpu/dcn_v2_im2col_cpu.h b/src/cpu/dcn_v2_im2col_cpu.h
new file mode 100644
index 0000000..bad5c52
--- /dev/null
+++ b/src/cpu/dcn_v2_im2col_cpu.h
@@ -0,0 +1,99 @@
+
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.h
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1811.11168
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
+ */
+
+/***************** Adapted by Charles Shang *********************/
+// modified from the CUDA version for CPU use by Daniel K. Suhendro
+
+#ifndef DCN_V2_IM2COL_CPU
+#define DCN_V2_IM2COL_CPU
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+
+    void modulated_deformable_im2col_cpu(const float *data_im, const float *data_offset, const float *data_mask,
+                                         const int batch_size, const int channels, const int height_im, const int width_im,
+                                         const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+                                         const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+                                         const int dilation_h, const int dilation_w,
+                                         const int deformable_group, float *data_col);
+
+    void modulated_deformable_col2im_cpu(const float *data_col, const float *data_offset, const float *data_mask,
+                                         const int batch_size, const int channels, const int height_im, const int width_im,
+                                         const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+                                         const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+                                         const int dilation_h, const int dilation_w,
+                                         const int deformable_group, float *grad_im);
+
+    void modulated_deformable_col2im_coord_cpu(const float *data_col, const float *data_im, const float *data_offset, const float *data_mask,
+                                               const int batch_size, const int channels, const int height_im, const int width_im,
+                                               const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+                                               const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+                                               const int dilation_h, const int dilation_w,
+                                               const int deformable_group,
+                                               float *grad_offset, float *grad_mask);
+
+#ifdef __cplusplus
}
+#endif
+
+#endif
\ No newline at end of file
diff --git a/src/cpu/dcn_v2_psroi_pooling_cpu.cpp b/src/cpu/dcn_v2_psroi_pooling_cpu.cpp
new file mode 100644
index 0000000..6e41aae
--- /dev/null
+++ b/src/cpu/dcn_v2_psroi_pooling_cpu.cpp
@@ -0,0 +1,426 @@
+/*!
+ * Copyright (c) 2017 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file deformable_psroi_pooling.cu
+ * \brief
+ * \author Yi Li, Guodong Zhang, Jifeng Dai
+*/
+/***************** Adapted by Charles Shang *********************/
+// modified from the CUDA version for CPU use by Daniel K.
Suhendro + +#include +#include +#include + +#include +//#include + +#include +//#include +//#include + +/*#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +}*/ + +template +T bilinear_interp_cpu( + const T *data, + const T x, + const T y, + const int width, + const int height) +{ + int x1 = floor(x); + int x2 = ceil(x); + int y1 = floor(y); + int y2 = ceil(y); + T dist_x = static_cast(x - x1); + T dist_y = static_cast(y - y1); + T value11 = data[y1 * width + x1]; + T value12 = data[y2 * width + x1]; + T value21 = data[y1 * width + x2]; + T value22 = data[y2 * width + x2]; + T value = (1 - dist_x) * (1 - dist_y) * value11 + + (1 - dist_x) * dist_y * value12 + + dist_x * (1 - dist_y) * value21 + + dist_x * dist_y * value22; + return value; +} + +template + void DeformablePSROIPoolForwardKernelCpu( + const int count, + const T *bottom_data, + const T spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const T *bottom_rois, const T *bottom_trans, + const int no_trans, + const T trans_std, + const int sample_per_part, + const int output_dim, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class, + T *top_data, + T *top_count) +{ + for(int index = 0; index < count; index++) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const T *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0 + T roi_height = std::max(roi_end_h - roi_start_h, T(0.1)); + + // Compute w and h at bottom + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); + T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); + + int part_h = floor(static_cast(ph) / pooled_height * part_size); + int part_w = floor(static_cast(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + T trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; + T trans_y = no_trans ? 
static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; + + T wstart = static_cast(pw) * bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + T hstart = static_cast(ph) * bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + T sum = 0; + int count = 0; + int gw = floor(static_cast(pw) * group_size / pooled_width); + int gh = floor(static_cast(ph) * group_size / pooled_height); + gw = std::min(std::max(gw, 0), group_size - 1); + gh = std::min(std::max(gh, 0), group_size - 1); + + const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; + for (int ih = 0; ih < sample_per_part; ih++) + { + for (int iw = 0; iw < sample_per_part; iw++) + { + T w = wstart + iw * sub_bin_size_w; + T h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) + { + continue; + } + w = std::min(std::max(w, T(0.)), width - T(1.)); + h = std::min(std::max(h, T(0.)), height - T(1.)); + int c = (ctop * group_size + gh) * group_size + gw; + T val = bilinear_interp_cpu(offset_bottom_data + c * height * width, w, h, width, height); + sum += val; + count++; + } + } + top_data[index] = count == 0 ? static_cast(0) : sum / count; + top_count[index] = count; + } +} + +template +void DeformablePSROIPoolBackwardAccKernelCpu( + const int count, + const T *top_diff, + const T *top_count, + const int num_rois, + const T spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int output_dim, + T *bottom_data_diff, T *bottom_trans_diff, + const T *bottom_data, + const T *bottom_rois, + const T *bottom_trans, + const int no_trans, + const T trans_std, + const int sample_per_part, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class) +{ + for(int index = 0; index < count; index++) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const T *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + T roi_start_w = static_cast(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + T roi_start_h = static_cast(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + T roi_end_w = static_cast(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + T roi_end_h = static_cast(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + T roi_width = std::max(roi_end_w - roi_start_w, T(0.1)); //avoid 0 + T roi_height = std::max(roi_end_h - roi_start_h, T(0.1)); + + // Compute w and h at bottom + T bin_size_h = roi_height / static_cast(pooled_height); + T bin_size_w = roi_width / static_cast(pooled_width); + + T sub_bin_size_h = bin_size_h / static_cast(sample_per_part); + T sub_bin_size_w = bin_size_w / static_cast(sample_per_part); + + int part_h = floor(static_cast(ph) / pooled_height * part_size); + int part_w = floor(static_cast(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + T trans_x = no_trans ? static_cast(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; + T trans_y = no_trans ? 
static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
+
+        T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
+        wstart += trans_x * roi_width;
+        T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
+        hstart += trans_y * roi_height;
+
+        if (top_count[index] <= 0)
+        {
+            continue;
+        }
+        T diff_val = top_diff[index] / top_count[index];
+        const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
+        T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
+        int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
+        int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
+        gw = std::min(std::max(gw, 0), group_size - 1);
+        gh = std::min(std::max(gh, 0), group_size - 1);
+
+        for (int ih = 0; ih < sample_per_part; ih++)
+        {
+            for (int iw = 0; iw < sample_per_part; iw++)
+            {
+                T w = wstart + iw * sub_bin_size_w;
+                T h = hstart + ih * sub_bin_size_h;
+                // bilinear interpolation
+                if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
+                {
+                    continue;
+                }
+                w = std::min(std::max(w, T(0.)), width - T(1.));
+                h = std::min(std::max(h, T(0.)), height - T(1.));
+                int c = (ctop * group_size + gh) * group_size + gw;
+                // backward on feature
+                int x0 = floor(w);
+                int x1 = ceil(w);
+                int y0 = floor(h);
+                int y1 = ceil(h);
+                T dist_x = w - x0, dist_y = h - y0;
+                T q00 = (1 - dist_x) * (1 - dist_y);
+                T q01 = (1 - dist_x) * dist_y;
+                T q10 = dist_x * (1 - dist_y);
+                T q11 = dist_x * dist_y;
+                int bottom_index_base = c * height * width;
+                /*atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
+                atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
+                atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
+                atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);*/
+                *(offset_bottom_data_diff + bottom_index_base + y0 * width + x0) += q00 * diff_val;
+                *(offset_bottom_data_diff + bottom_index_base + y1 * width + x0) += q01 * diff_val;
+                *(offset_bottom_data_diff + bottom_index_base + y0 * width + x1) += q10 * diff_val;
+                *(offset_bottom_data_diff + bottom_index_base + y1 * width + x1) += q11 * diff_val;
+
+                if (no_trans)
+                {
+                    continue;
+                }
+                T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
+                T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
+                T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
+                T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
+                T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
+                diff_x *= roi_width;
+                T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
+                diff_y *= roi_height;
+
+                /*atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
+                atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);*/
+                *(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w) += diff_x;
+                *(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w) += diff_y;
+            }
+        }
+    }
+}
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
+                                 const at::Tensor &bbox,
+                                 const at::Tensor &trans,
+                                 const int no_trans,
+                                 const float spatial_scale,
+                                 const int output_dim,
+                                 const int group_size,
+                                 const int pooled_size,
+                                 const int part_size,
+                                 const int sample_per_part,
+                                 const float trans_std)
+{
+    /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor");
+    AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");*/
+
+    const int batch = input.size(0);
+    const int channels = input.size(1);
+    const int height = input.size(2);
+    const int width = input.size(3);
+    const int channels_trans = no_trans ? 2 : trans.size(1);
+    const int num_bbox = bbox.size(0);
+
+    AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
+    auto pooled_height = pooled_size;
+    auto pooled_width = pooled_size;
+
+    auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
+    long out_size = num_bbox * output_dim * pooled_height * pooled_width;
+    auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
+
+    const int num_classes = no_trans ? 1 : channels_trans / 2;
+    const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+    //cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    if (out.numel() == 0)
+    {
+        //THCudaCheck(cudaGetLastError());
+        return std::make_tuple(out, top_count);
+    }
+
+    /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
+    dim3 block(512);*/
+
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cpu_forward", [&] {
+        DeformablePSROIPoolForwardKernelCpu<scalar_t>(
+            out_size,
+            input.contiguous().data<scalar_t>(),
+            spatial_scale,
+            channels,
+            height, width,
+            pooled_height,
+            pooled_width,
+            bbox.contiguous().data<scalar_t>(),
+            trans.contiguous().data<scalar_t>(),
+            no_trans,
+            trans_std,
+            sample_per_part,
+            output_dim,
+            group_size,
+            part_size,
+            num_classes,
+            channels_each_class,
+            out.data<scalar_t>(),
+            top_count.data<scalar_t>());
+    });
+    //THCudaCheck(cudaGetLastError());
+    return std::make_tuple(out, top_count);
+}
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
+                                  const at::Tensor &input,
+                                  const at::Tensor &bbox,
+                                  const at::Tensor &trans,
+                                  const at::Tensor &top_count,
+                                  const int no_trans,
+                                  const float spatial_scale,
+                                  const int output_dim,
+                                  const int group_size,
+                                  const int pooled_size,
+                                  const int part_size,
+                                  const int sample_per_part,
+                                  const float trans_std)
+{
+    /*AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor");
+    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor");
+    AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
+    AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor");*/
+
+    const int batch = input.size(0);
+    const int channels = input.size(1);
+    const int height = input.size(2);
+    const int width = input.size(3);
+    const int channels_trans = no_trans ? 2 : trans.size(1);
+    const int num_bbox = bbox.size(0);
+
+    AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
+    auto pooled_height = pooled_size;
+    auto pooled_width = pooled_size;
+    long out_size = num_bbox * output_dim * pooled_height * pooled_width;
+    const int num_classes = no_trans ? 1 : channels_trans / 2;
+    const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+    auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());
+    auto trans_grad = at::zeros_like(trans);
+
+    if (input_grad.numel() == 0)
+    {
+        //THCudaCheck(cudaGetLastError());
+        return std::make_tuple(input_grad, trans_grad);
+    }
+
+    /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
+    dim3 block(512);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();*/
+
+    AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cpu_backward", [&] {
+        DeformablePSROIPoolBackwardAccKernelCpu<scalar_t>(
+            out_size,
+            out_grad.contiguous().data<scalar_t>(),
+            top_count.contiguous().data<scalar_t>(),
+            num_bbox,
+            spatial_scale,
+            channels,
+            height,
+            width,
+            pooled_height,
+            pooled_width,
+            output_dim,
+            input_grad.contiguous().data<scalar_t>(),
+            trans_grad.contiguous().data<scalar_t>(),
+            input.contiguous().data<scalar_t>(),
+            bbox.contiguous().data<scalar_t>(),
+            trans.contiguous().data<scalar_t>(),
+            no_trans,
+            trans_std,
+            sample_per_part,
+            group_size,
+            part_size,
+            num_classes,
+            channels_each_class);
+    });
+    //THCudaCheck(cudaGetLastError());
+    return std::make_tuple(input_grad, trans_grad);
+}
\ No newline at end of file
diff --git a/src/cuda/dcn_v2_cuda.cu b/src/cuda/dcn_v2_cuda.cu
index 767ed8f..c90ee04 100644
--- a/src/cuda/dcn_v2_cuda.cu
+++ b/src/cuda/dcn_v2_cuda.cu
@@ -3,12 +3,48 @@
 #include
 #include
-
+#include
+#include
+#include
 #include
 #include
 #include
+#include
+#include
+
+THCState *state = at::globalContext().lazyInitCUDA();
+
+static cublasOperation_t _cublasOpFromChar(char op) {
+  switch (op) {
+  case 'n':
+  case 'N':
+    return CUBLAS_OP_N;
+  case 't':
+  case 'T':
+    return CUBLAS_OP_T;
+  case 'c':
+  case 'C':
+    return CUBLAS_OP_C;
+  }
+  AT_ERROR(
+      "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`");
+}
+
+static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) {
+  // Note: leading dimensions generally are checked that they are > 0
+  // and at least as big the result requires (even if the value won't
+  // be used).
+
+  // Q: Why does Level3 check trans but this doesn't?
+  // A: In level 2, the sizes (m, n) specify the size of A
+  // (independent of trans value). In level 3. the sizes (m, n, k)
+  // specify the sizes of op(A), op(B) where op depend on trans
+  // values.
+ if (n <= 1) + *lda = std::max(m, 1); + } + -extern THCState *state; // author: Charles Shang // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu @@ -104,16 +140,16 @@ dcn_v2_cuda_forward(const at::Tensor &input, const int block = 128; const int grid = (batch + block - 1) / block; - createBatchGemmBuffer<<>>( + createBatchGemmBuffer<<>>( input_b, output_b, columns_b, ones_b, weight_b, bias_b, - input.data(), - output.data(), - columns.data(), - ones.data(), - weight.data(), - bias.data(), + input.data_ptr(), + output.data_ptr(), + columns.data_ptr(), + ones.data_ptr(), + weight.data_ptr(), + bias.data_ptr(), channels * width * height, channels_out * width_out * height_out, channels * kernel_h * kernel_w * height_out * width_out, @@ -136,15 +172,15 @@ dcn_v2_cuda_forward(const at::Tensor &input, output_b, n_, batch); - modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), - input.data(), - offset.data(), - mask.data(), + modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(), + input.data_ptr(), + offset.data_ptr(), + mask.data_ptr(), batch, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - columns.data()); + columns.data_ptr()); long m = channels_out; long n = height_out * width_out; @@ -271,63 +307,63 @@ std::vector dcn_v2_cuda_backward(const at::Tensor &input, long k = channels_out; THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, - grad_output_n.data(), n, - weight.data(), m, 0.0f, - columns.data(), n); + grad_output_n.data_ptr(), n, + weight.data_ptr(), m, 0.0f, + columns.data_ptr(), n); // gradient w.r.t. input coordinate data - modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state), - columns.data(), - input_n.data(), - offset_n.data(), - mask_n.data(), + modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(), + columns.data_ptr(), + input_n.data_ptr(), + offset_n.data_ptr(), + mask_n.data_ptr(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - grad_offset_n.data(), - grad_mask_n.data()); + grad_offset_n.data_ptr(), + grad_mask_n.data_ptr()); // gradient w.r.t. input data - modulated_deformable_col2im_cuda(THCState_getCurrentStream(state), - columns.data(), - offset_n.data(), - mask_n.data(), + modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(), + columns.data_ptr(), + offset_n.data_ptr(), + mask_n.data_ptr(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - grad_input_n.data()); + grad_input_n.data_ptr()); // gradient w.r.t. weight, dWeight should accumulate across the batch and group - modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), - input_n.data(), - offset_n.data(), - mask_n.data(), + modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(), + input_n.data_ptr(), + offset_n.data_ptr(), + mask_n.data_ptr(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - columns.data()); + columns.data_ptr()); long m_ = channels_out; long n_ = channels * kernel_h * kernel_w; long k_ = height_out * width_out; THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, - columns.data(), k_, - grad_output_n.data(), k_, 1.0f, - grad_weight.data(), n_); - - // gradient w.r.t. 
bias - // long m_ = channels_out; - // long k__ = height_out * width_out; - THCudaBlas_Sgemv(state, - 't', - k_, m_, 1.0f, - grad_output_n.data(), k_, - ones.data(), 1, 1.0f, - grad_bias.data(), 1); + columns.data_ptr(), k_, + grad_output_n.data_ptr(), k_, 1.0f, + grad_weight.data_ptr(), n_); + + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasOperation_t op = _cublasOpFromChar('t'); + _cublasAdjustLdLevel2(k_, m_, &k_); + float* grad_output_n_float = grad_output_n.data_ptr(); + float* one_float = ones.data_ptr(); + float alpha = 1.0f; + float beta = 1.0f; + cublasSgemv(handle, op, k_, m_, &alpha, grad_output_n_float,k_, one_float,1, &beta, grad_bias.data_ptr(), 1); } + return { grad_input, grad_offset, grad_mask, grad_weight, grad_bias diff --git a/src/cuda/dcn_v2_im2col_cuda.cu b/src/cuda/dcn_v2_im2col_cuda.cu index 4183793..4140eac 100644 --- a/src/cuda/dcn_v2_im2col_cuda.cu +++ b/src/cuda/dcn_v2_im2col_cuda.cu @@ -22,7 +22,7 @@ inline int GET_BLOCKS(const int N) } -__device__ float dmcn_im2col_bilinear(const float *bottom_data, const int data_width, +__device__ float dmcn_im2col_bilinear_cuda(const float *bottom_data, const int data_width, const int height, const int width, float h, float w) { int h_low = floor(h); @@ -53,7 +53,7 @@ __device__ float dmcn_im2col_bilinear(const float *bottom_data, const int data_w return val; } -__device__ float dmcn_get_gradient_weight(float argmax_h, float argmax_w, +__device__ float dmcn_get_gradient_weight_cuda(float argmax_h, float argmax_w, const int h, const int w, const int height, const int width) { if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) @@ -79,7 +79,7 @@ __device__ float dmcn_get_gradient_weight(float argmax_h, float argmax_w, return weight; } -__device__ float dmcn_get_coordinate_weight(float argmax_h, float argmax_w, +__device__ float dmcn_get_coordinate_weight_cuda(float argmax_h, float argmax_w, const int height, const int width, const float *im_data, const int data_width, const int bp_dir) { @@ -183,8 +183,8 @@ __global__ void modulated_deformable_im2col_gpu_kernel(const int n, //const float map_w = j * dilation_w + offset_w; //const int cur_height = height - h_in; //const int cur_width = width - w_in; - //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); - val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + //val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear_cuda(data_im_ptr, width, height, width, h_im, w_im); } *data_col_ptr = val * mask; // data_col_ptr += batch_size * height_col * width_col; @@ -245,7 +245,7 @@ __global__ void modulated_deformable_col2im_gpu_kernel(const int n, abs(cur_inv_w_data - (cur_w + dx)) < 1) { int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; - float weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + float weight = dmcn_get_gradient_weight_cuda(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); } } @@ -310,9 +310,9 @@ __global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, } else { - mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear_cuda(data_im_ptr + cnt * height * width, width, height, width, 
inv_h, inv_w); } - const float weight = dmcn_get_coordinate_weight( + const float weight = dmcn_get_coordinate_weight_cuda( inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, width, bp_dir); val += weight * data_col_ptr[col_pos] * mask; diff --git a/src/cuda/dcn_v2_psroi_pooling_cuda.cu b/src/cuda/dcn_v2_psroi_pooling_cuda.cu index 07b438e..bf99f0c 100644 --- a/src/cuda/dcn_v2_psroi_pooling_cuda.cu +++ b/src/cuda/dcn_v2_psroi_pooling_cuda.cu @@ -31,7 +31,7 @@ inline int GET_BLOCKS(const int N) } template -__device__ T bilinear_interp( +__device__ T bilinear_interp_cuda( const T *data, const T x, const T y, @@ -56,7 +56,7 @@ __device__ T bilinear_interp( } template -__global__ void DeformablePSROIPoolForwardKernel( +__global__ void DeformablePSROIPoolForwardKernelCuda( const int count, const T *bottom_data, const T spatial_scale, @@ -135,7 +135,7 @@ __global__ void DeformablePSROIPoolForwardKernel( w = min(max(w, 0.), width - 1.); h = min(max(h, 0.), height - 1.); int c = (ctop * group_size + gh) * group_size + gw; - T val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height); + T val = bilinear_interp_cuda(offset_bottom_data + c * height * width, w, h, width, height); sum += val; count++; } @@ -146,7 +146,7 @@ __global__ void DeformablePSROIPoolForwardKernel( } template -__global__ void DeformablePSROIPoolBackwardAccKernel( +__global__ void DeformablePSROIPoolBackwardAccKernelCuda( const int count, const T *top_diff, const T *top_count, @@ -315,16 +315,16 @@ dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, dim3 block(512); AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] { - DeformablePSROIPoolForwardKernel<<>>( + DeformablePSROIPoolForwardKernelCuda<<>>( out_size, - input.contiguous().data(), + input.contiguous().data_ptr(), spatial_scale, channels, height, width, pooled_height, pooled_width, - bbox.contiguous().data(), - trans.contiguous().data(), + bbox.contiguous().data_ptr(), + trans.contiguous().data_ptr(), no_trans, trans_std, sample_per_part, @@ -333,8 +333,8 @@ dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, part_size, num_classes, channels_each_class, - out.data(), - top_count.data()); + out.data_ptr(), + top_count.data_ptr()); }); THCudaCheck(cudaGetLastError()); return std::make_tuple(out, top_count); @@ -389,10 +389,10 @@ dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] { - DeformablePSROIPoolBackwardAccKernel<<>>( + DeformablePSROIPoolBackwardAccKernelCuda<<>>( out_size, - out_grad.contiguous().data(), - top_count.contiguous().data(), + out_grad.contiguous().data_ptr(), + top_count.contiguous().data_ptr(), num_bbox, spatial_scale, channels, @@ -401,11 +401,11 @@ dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, pooled_height, pooled_width, output_dim, - input_grad.contiguous().data(), - trans_grad.contiguous().data(), - input.contiguous().data(), - bbox.contiguous().data(), - trans.contiguous().data(), + input_grad.contiguous().data_ptr(), + trans_grad.contiguous().data_ptr(), + input.contiguous().data_ptr(), + bbox.contiguous().data_ptr(), + trans.contiguous().data_ptr(), no_trans, trans_std, sample_per_part, diff --git a/src/cuda/vision.h b/src/cuda/vision.h index e42a2a7..f3672b1 100644 --- a/src/cuda/vision.h +++ b/src/cuda/vision.h @@ -1,6 +1,6 @@ #pragma once #include - +#include at::Tensor 
dcn_v2_cuda_forward(const at::Tensor &input, const at::Tensor &weight, diff --git a/src/dcn_v2.h b/src/dcn_v2.h index 23f5caf..de670bf 100644 --- a/src/dcn_v2.h +++ b/src/dcn_v2.h @@ -35,7 +35,14 @@ dcn_v2_forward(const at::Tensor &input, AT_ERROR("Not compiled with GPU support"); #endif } - AT_ERROR("Not implemented on the CPU"); + else{ + return dcn_v2_cpu_forward(input, weight, bias, offset, mask, + kernel_h, kernel_w, + stride_h, stride_w, + pad_h, pad_w, + dilation_h, dilation_w, + deformable_group); + } } std::vector @@ -69,7 +76,19 @@ dcn_v2_backward(const at::Tensor &input, AT_ERROR("Not compiled with GPU support"); #endif } - AT_ERROR("Not implemented on the CPU"); + else{ + return dcn_v2_cpu_backward(input, + weight, + bias, + offset, + mask, + grad_output, + kernel_h, kernel_w, + stride_h, stride_w, + pad_h, pad_w, + dilation_h, dilation_w, + deformable_group); + } } std::tuple @@ -103,7 +122,19 @@ dcn_v2_psroi_pooling_forward(const at::Tensor &input, AT_ERROR("Not compiled with GPU support"); #endif } - AT_ERROR("Not implemented on the CPU"); + else{ + return dcn_v2_psroi_pooling_cpu_forward(input, + bbox, + trans, + no_trans, + spatial_scale, + output_dim, + group_size, + pooled_size, + part_size, + sample_per_part, + trans_std); + } } std::tuple @@ -141,5 +172,19 @@ dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad, AT_ERROR("Not compiled with GPU support"); #endif } - AT_ERROR("Not implemented on the CPU"); + else{ + return dcn_v2_psroi_pooling_cpu_backward(out_grad, + input, + bbox, + trans, + top_count, + no_trans, + spatial_scale, + output_dim, + group_size, + pooled_size, + part_size, + sample_per_part, + trans_std); + } } \ No newline at end of file From 977dfd39b3536ba20baee06d1055d2bd5a5e9593 Mon Sep 17 00:00:00 2001 From: Li Bin Date: Wed, 29 Jul 2020 21:23:22 +0800 Subject: [PATCH 2/3] Update dcn_v2_cuda.cu --- src/cuda/dcn_v2_cuda.cu | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/cuda/dcn_v2_cuda.cu b/src/cuda/dcn_v2_cuda.cu index c90ee04..cef3068 100644 --- a/src/cuda/dcn_v2_cuda.cu +++ b/src/cuda/dcn_v2_cuda.cu @@ -357,15 +357,16 @@ std::vector dcn_v2_cuda_backward(const at::Tensor &input, cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); cublasOperation_t op = _cublasOpFromChar('t'); _cublasAdjustLdLevel2(k_, m_, &k_); - float* grad_output_n_float = grad_output_n.data_ptr(); - float* one_float = ones.data_ptr(); - float alpha = 1.0f; - float beta = 1.0f; - cublasSgemv(handle, op, k_, m_, &alpha, grad_output_n_float,k_, one_float,1, &beta, grad_bias.data_ptr(), 1); + scalar_t* grad_output_n_float = grad_output_n.data_ptr(); + scalar_t* one_float = ones.data_ptr(); + scalar_t alpha = 1.0; + scalar_t beta = 1.0; + cublasSgemv(handle, op, k_, m_, &alpha, grad_output_n_float,k_, one_float,1, &beta, grad_bias.data_ptr(), 1); + } return { grad_input, grad_offset, grad_mask, grad_weight, grad_bias }; -} \ No newline at end of file +} From 41c9ebb3f7e9e563529bb5efc831c3922de72b08 Mon Sep 17 00:00:00 2001 From: Li Bin Date: Mon, 23 Nov 2020 18:01:50 +0800 Subject: [PATCH 3/3] add torch 1.7 support --- dcn_v2.py | 51 ++--- setup.py | 7 +- src/cpu/dcn_v2_cpu.cpp | 118 +++++----- src/cuda/dcn_v2_cuda.cu | 125 ++++------- src/cuda/dcn_v2_psroi_pooling_cuda.cu | 24 +-- src/cuda/vision.h | 2 +- test.py => test/test.py | 197 +++++++++-------- test/testcpu.py | 295 +++++++++++++++++++++++++ test/testcuda.py | 299 ++++++++++++++++++++++++++ 9 files changed, 844 insertions(+), 274 deletions(-) rename test.py => test/test.py 
(57%) create mode 100644 test/testcpu.py create mode 100644 test/testcuda.py diff --git a/dcn_v2.py b/dcn_v2.py index ce654b5..a2014e3 100644 --- a/dcn_v2.py +++ b/dcn_v2.py @@ -14,9 +14,7 @@ class _DCNv2(Function): @staticmethod - def forward( - ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups - ): + def forward(ctx, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups): ctx.stride = _pair(stride) ctx.padding = _pair(padding) ctx.dilation = _pair(dilation) @@ -63,30 +61,16 @@ def backward(ctx, grad_output): ctx.deformable_groups, ) - return grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None - - @staticmethod - def symbolic( - g, input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups - ): - from torch.nn.modules.utils import _pair - - stride = _pair(stride) - padding = _pair(padding) - dilation = _pair(dilation) - # as of trt 7, the dcn operation will be translated again by modifying the onnx file - # so the exporting code is kept to resemble the forward() - return g.op( - "DCNv2_2", - input, - offset, - mask, - weight, - bias, - stride_i=stride, - padding_i=padding, - dilation_i=dilation, - deformable_groups_i=deformable_groups, + return ( + grad_input, + grad_offset, + grad_mask, + grad_weight, + grad_bias, + None, + None, + None, + None, ) @@ -126,10 +110,7 @@ def reset_parameters(self): self.bias.data.zero_() def forward(self, input, offset, mask): - assert ( - 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] - == offset.shape[1] - ) + assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == offset.shape[1] assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == mask.shape[1] return dcn_v2_conv( input, @@ -155,9 +136,7 @@ def __init__( dilation=1, deformable_groups=1, ): - super(DCN, self).__init__( - in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups - ) + super(DCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups) channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1] self.conv_offset_mask = nn.Conv2d( @@ -328,9 +307,7 @@ def __init__( if not no_trans: self.offset_mask_fc = nn.Sequential( - nn.Linear( - self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim - ), + nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim), nn.ReLU(inplace=True), nn.Linear(self.deform_fc_dim, self.deform_fc_dim), nn.ReLU(inplace=True), diff --git a/setup.py b/setup.py index 7ee6edb..887cce2 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,12 @@ def get_extensions(): author="charlesshang", url="https://github.com/charlesshang/DCNv2", description="deformable convolutional networks", - packages=find_packages(exclude=("configs", "tests")), + packages=find_packages( + exclude=( + "configs", + "tests", + ) + ), # install_requires=requirements, ext_modules=get_extensions(), cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, diff --git a/src/cpu/dcn_v2_cpu.cpp b/src/cpu/dcn_v2_cpu.cpp index 76c65f0..8d76c28 100644 --- a/src/cpu/dcn_v2_cpu.cpp +++ b/src/cpu/dcn_v2_cpu.cpp @@ -1,5 +1,6 @@ #include #include "cpu/dcn_v2_im2col_cpu.h" +#include #include //#include @@ -12,8 +13,12 @@ // author: Charles Shang // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu + // modified from the CUDA version for CPU use by Daniel K. 
Suhendro +// edit by: James Bockman and Matthew Howe +// modified for torch implementation to remove use of deprecated torch access to Blas + at::Tensor dcn_v2_cpu_forward(const at::Tensor &input, const at::Tensor &weight, @@ -60,9 +65,10 @@ dcn_v2_cpu_forward(const at::Tensor &input, const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - auto ones = at::ones({height_out, width_out}, input.options()); + // auto ones = at::ones({height_out, width_out}, input.options()); + auto ones = at::ones({bias.sizes()[0], height_out, width_out}, input.options()); auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); - auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); + auto output = at::zeros({batch, channels_out, height_out, width_out}, input.options()); using scalar_t = float; for (int b = 0; b < batch; b++) @@ -71,37 +77,35 @@ dcn_v2_cpu_forward(const at::Tensor &input, auto offset_n = offset.select(0, b); auto mask_n = mask.select(0, b); auto output_n = output.select(0, b); + // std::cout << "output_n: " << output_n << "output.select(0,b): " << output.select(0,b) << "\n"; // Do Bias first: // M,N,K are dims of matrix A and B // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) // (N x 1) (1 x M) - long m_ = channels_out; - long n_ = height_out * width_out; - long k_ = 1; - THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f, - ones.contiguous().data(), k_, - bias.contiguous().data(), k_, 0.0f, - output_n.data(), n_); - - modulated_deformable_im2col_cpu(input_n.data(), - offset_n.data(), - mask_n.data(), + + // torch implementation + auto ones_T = at::transpose(ones.contiguous(), 2, 0); + ones_T = at::mul(ones_T, bias.contiguous()); + ones_T = at::transpose(ones_T, 2, 0); + output_n = at::add(output_n, ones_T); + + modulated_deformable_im2col_cpu(input_n.data_ptr(), + offset_n.data_ptr(), + mask_n.data_ptr(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - columns.data()); + columns.data_ptr()); //(k * m) x (m * n) // Y = WC - long m = channels_out; - long n = height_out * width_out; - long k = channels * kernel_h * kernel_w; - THFloatBlas_gemm('n', 'n', n, m, k, 1.0f, - columns.data(), n, - weight.data(), k, 1.0f, - output_n.data(), n); + + // torch implementation + auto weight_flat = weight.view({channels_out, channels * kernel_h * kernel_w}); + auto product = at::matmul(weight_flat, columns); + output.select(0, b) = at::add(output_n, product.view({channels_out, height_out, width_out})); } return output; } @@ -148,7 +152,7 @@ std::vector dcn_v2_cpu_backward(const at::Tensor &input, const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; auto ones = at::ones({height_out, width_out}, input.options()); - auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); + auto columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options()); auto output = at::empty({batch, channels_out, height_out, width_out}, input.options()); auto grad_input = at::zeros_like(input); @@ -169,65 +173,57 @@ std::vector dcn_v2_cpu_backward(const at::Tensor &input, auto grad_offset_n = grad_offset.select(0, b); auto grad_mask_n = grad_mask.select(0, b); - long m = channels * kernel_h * 
kernel_w; - long n = height_out * width_out; - long k = channels_out; - THFloatBlas_gemm('n', 't', n, m, k, 1.0f, - grad_output_n.data(), n, - weight.data(), m, 0.0f, - columns.data(), n); + + // Torch implementation + auto weight_flat = weight.view({channels_out, channels*kernel_h*kernel_w}); + weight_flat = at::transpose(weight_flat, 1, 0); + auto grad_output_n_flat = grad_output_n.view({channels_out, height_out*width_out}); + columns = at::matmul(weight_flat, grad_output_n_flat); // gradient w.r.t. input coordinate data - modulated_deformable_col2im_coord_cpu(columns.data(), - input_n.data(), - offset_n.data(), - mask_n.data(), + modulated_deformable_col2im_coord_cpu(columns.data_ptr(), + input_n.data_ptr(), + offset_n.data_ptr(), + mask_n.data_ptr(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - grad_offset_n.data(), - grad_mask_n.data()); + grad_offset_n.data_ptr(), + grad_mask_n.data_ptr()); // gradient w.r.t. input data - modulated_deformable_col2im_cpu(columns.data(), - offset_n.data(), - mask_n.data(), + modulated_deformable_col2im_cpu(columns.data_ptr(), + offset_n.data_ptr(), + mask_n.data_ptr(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - grad_input_n.data()); + grad_input_n.data_ptr()); // gradient w.r.t. weight, dWeight should accumulate across the batch and group - modulated_deformable_im2col_cpu(input_n.data(), - offset_n.data(), - mask_n.data(), + modulated_deformable_im2col_cpu(input_n.data_ptr(), + offset_n.data_ptr(), + mask_n.data_ptr(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - columns.data()); - - long m_ = channels_out; - long n_ = channels * kernel_h * kernel_w; - long k_ = height_out * width_out; - - THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f, - columns.data(), k_, - grad_output_n.data(), k_, 1.0f, - grad_weight.data(), n_); - - // gradient w.r.t. 
bias - // long m_ = channels_out; - // long k__ = height_out * width_out; - // THFloatBlas_gemv('t', k_, m_, 1.0f, - // grad_output_n.data(), k_, - // ones.data(), 1, 1.0f, - // grad_bias.data(), 1); + columns.data_ptr()); + + // Torch implementation + auto product = at::matmul(grad_output_n_flat, at::transpose(columns, 1, 0)); + grad_weight = at::add(grad_weight, product.view({channels_out, channels, kernel_h, kernel_w})); + + + // Torch implementation + auto ones_flat = ones.view({height_out*width_out}); + product = at::matmul(grad_output_n_flat, ones_flat); + grad_bias = at::add(grad_bias, product); } return { grad_input, grad_offset, grad_mask, grad_weight, grad_bias }; -} \ No newline at end of file +} diff --git a/src/cuda/dcn_v2_cuda.cu b/src/cuda/dcn_v2_cuda.cu index cef3068..a23180e 100644 --- a/src/cuda/dcn_v2_cuda.cu +++ b/src/cuda/dcn_v2_cuda.cu @@ -3,49 +3,13 @@ #include #include -#include -#include -#include + #include #include #include -#include -#include THCState *state = at::globalContext().lazyInitCUDA(); -static cublasOperation_t _cublasOpFromChar(char op) { - switch (op) { - case 'n': - case 'N': - return CUBLAS_OP_N; - case 't': - case 'T': - return CUBLAS_OP_T; - case 'c': - case 'C': - return CUBLAS_OP_C; - } - AT_ERROR( - "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); - } - - static void _cublasAdjustLdLevel2(int64_t m, int64_t n, int64_t* lda) { - // Note: leading dimensions generally are checked that they are > 0 - // and at least as big the result requires (even if the value won't - // be used). - - // Q: Why does Level3 check trans but this doesn't? - // A: In level 2, the sizes (m, n) specify the size of A - // (independent of trans value). In level 3. the sizes (m, n, k) - // specify the sizes of op(A), op(B) where op depend on trans - // values. - if (n <= 1) - *lda = std::max(m, 1); - } - - - // author: Charles Shang // https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu @@ -144,12 +108,12 @@ dcn_v2_cuda_forward(const at::Tensor &input, input_b, output_b, columns_b, ones_b, weight_b, bias_b, - input.data_ptr(), - output.data_ptr(), - columns.data_ptr(), - ones.data_ptr(), - weight.data_ptr(), - bias.data_ptr(), + input.data(), + output.data(), + columns.data(), + ones.data(), + weight.data(), + bias.data(), channels * width * height, channels_out * width_out * height_out, channels * kernel_h * kernel_w * height_out * width_out, @@ -173,14 +137,14 @@ dcn_v2_cuda_forward(const at::Tensor &input, batch); modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(), - input.data_ptr(), - offset.data_ptr(), - mask.data_ptr(), + input.data(), + offset.data(), + mask.data(), batch, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - columns.data_ptr()); + columns.data()); long m = channels_out; long n = height_out * width_out; @@ -307,64 +271,69 @@ std::vector dcn_v2_cuda_backward(const at::Tensor &input, long k = channels_out; THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, - grad_output_n.data_ptr(), n, - weight.data_ptr(), m, 0.0f, - columns.data_ptr(), n); + grad_output_n.data(), n, + weight.data(), m, 0.0f, + columns.data(), n); // gradient w.r.t. 
input coordinate data modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(), - columns.data_ptr(), - input_n.data_ptr(), - offset_n.data_ptr(), - mask_n.data_ptr(), + columns.data(), + input_n.data(), + offset_n.data(), + mask_n.data(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - grad_offset_n.data_ptr(), - grad_mask_n.data_ptr()); + grad_offset_n.data(), + grad_mask_n.data()); // gradient w.r.t. input data modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(), - columns.data_ptr(), - offset_n.data_ptr(), - mask_n.data_ptr(), + columns.data(), + offset_n.data(), + mask_n.data(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - grad_input_n.data_ptr()); + grad_input_n.data()); // gradient w.r.t. weight, dWeight should accumulate across the batch and group modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(), - input_n.data_ptr(), - offset_n.data_ptr(), - mask_n.data_ptr(), + input_n.data(), + offset_n.data(), + mask_n.data(), 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, - columns.data_ptr()); + columns.data()); long m_ = channels_out; long n_ = channels * kernel_h * kernel_w; long k_ = height_out * width_out; THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, - columns.data_ptr(), k_, - grad_output_n.data_ptr(), k_, 1.0f, - grad_weight.data_ptr(), n_); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cublasOperation_t op = _cublasOpFromChar('t'); - _cublasAdjustLdLevel2(k_, m_, &k_); - scalar_t* grad_output_n_float = grad_output_n.data_ptr(); - scalar_t* one_float = ones.data_ptr(); - scalar_t alpha = 1.0; - scalar_t beta = 1.0; - cublasSgemv(handle, op, k_, m_, &alpha, grad_output_n_float,k_, one_float,1, &beta, grad_bias.data_ptr(), 1); - + columns.data(), k_, + grad_output_n.data(), k_, 1.0f, + grad_weight.data(), n_); + + // gradient w.r.t. 
bias + // long m_ = channels_out; + // long k__ = height_out * width_out; + // THCudaBlas_Sgemm(state, + // 't', 'n', + // k_, m_, 1, 1.0f, + // grad_output_n.data(), k_, + // ones.data(), 1, 1.0f, + // grad_bias.data(), 1); + THCudaBlas_Sgemm(state, + 'N', 'N', 1, m_, k_, 1.0f, + ones.data(), 1, + grad_output_n.data(), k_, + 1.0f, + grad_bias.data(), 1); } - return { grad_input, grad_offset, grad_mask, grad_weight, grad_bias diff --git a/src/cuda/dcn_v2_psroi_pooling_cuda.cu b/src/cuda/dcn_v2_psroi_pooling_cuda.cu index bf99f0c..8f08f6a 100644 --- a/src/cuda/dcn_v2_psroi_pooling_cuda.cu +++ b/src/cuda/dcn_v2_psroi_pooling_cuda.cu @@ -317,14 +317,14 @@ dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] { DeformablePSROIPoolForwardKernelCuda<<>>( out_size, - input.contiguous().data_ptr(), + input.contiguous().data(), spatial_scale, channels, height, width, pooled_height, pooled_width, - bbox.contiguous().data_ptr(), - trans.contiguous().data_ptr(), + bbox.contiguous().data(), + trans.contiguous().data(), no_trans, trans_std, sample_per_part, @@ -333,8 +333,8 @@ dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, part_size, num_classes, channels_each_class, - out.data_ptr(), - top_count.data_ptr()); + out.data(), + top_count.data()); }); THCudaCheck(cudaGetLastError()); return std::make_tuple(out, top_count); @@ -391,8 +391,8 @@ dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] { DeformablePSROIPoolBackwardAccKernelCuda<<>>( out_size, - out_grad.contiguous().data_ptr(), - top_count.contiguous().data_ptr(), + out_grad.contiguous().data(), + top_count.contiguous().data(), num_bbox, spatial_scale, channels, @@ -401,11 +401,11 @@ dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, pooled_height, pooled_width, output_dim, - input_grad.contiguous().data_ptr(), - trans_grad.contiguous().data_ptr(), - input.contiguous().data_ptr(), - bbox.contiguous().data_ptr(), - trans.contiguous().data_ptr(), + input_grad.contiguous().data(), + trans_grad.contiguous().data(), + input.contiguous().data(), + bbox.contiguous().data(), + trans.contiguous().data(), no_trans, trans_std, sample_per_part, diff --git a/src/cuda/vision.h b/src/cuda/vision.h index f3672b1..e42a2a7 100644 --- a/src/cuda/vision.h +++ b/src/cuda/vision.h @@ -1,6 +1,6 @@ #pragma once #include -#include + at::Tensor dcn_v2_cuda_forward(const at::Tensor &input, const at::Tensor &weight, diff --git a/test.py b/test/test.py similarity index 57% rename from test.py rename to test/test.py index 3bd5bd2..cade313 100644 --- a/test.py +++ b/test/test.py @@ -1,15 +1,11 @@ #!/usr/bin/env python -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division +from __future__ import absolute_import, division, print_function -import time import torch import torch.nn as nn from torch.autograd import gradcheck -from dcn_v2 import dcn_v2_conv, DCNv2, DCN -from dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling +from dcn_v2 import DCN, DCNPooling, DCNv2, DCNv2Pooling, dcn_v2_conv, dcn_v2_pooling deformable_groups = 1 N, inC, inH, inW = 2, 2, 4, 4 @@ -21,8 +17,8 @@ def conv_identify(weight, bias): weight.data.zero_() bias.data.zero_() o, i, h, w = weight.shape - y = h//2 - x = w//2 + y = h // 2 + x = w // 2 for p in range(i): for q in range(o): if p == q: @@ -30,21 +26,25 @@ def conv_identify(weight, bias): 
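A note on the bias-gradient hunk above: the removed cublasSgemv path is re-expressed as a one-row THCudaBlas_Sgemm against the ones vector. The identity relied on, sketched in PyTorch with assumed shapes (names here are illustrative, not from the extension):

import torch

channels_out, spatial = 4, 64                       # m = channels_out, k = height_out * width_out
grad_output_n = torch.randn(channels_out, spatial)  # one sample's gradient, flattened over H*W
ones = torch.ones(spatial)

# gemv view: the gradient matrix applied to the ones vector
gemv = grad_output_n @ ones
# gemm view used by the patch: (1 x k) @ (k x m) -> (1 x m)
gemm = (ones.unsqueeze(0) @ grad_output_n.t()).squeeze(0)

assert torch.allclose(gemv, gemm, atol=1e-5)
assert torch.allclose(gemm, grad_output_n.sum(dim=1), atol=1e-5)  # the per-channel spatial sum feeding grad_bias

Both views accumulate the same per-channel spatial sum into grad_bias, so swapping the BLAS call leaves the result unchanged.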
def check_zero_offset(): - conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW, - kernel_size=(kH, kW), - stride=(1, 1), - padding=(1, 1), - bias=True).cuda() - - conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW, - kernel_size=(kH, kW), - stride=(1, 1), - padding=(1, 1), - bias=True).cuda() - - dcn_v2 = DCNv2(inC, outC, (kH, kW), - stride=1, padding=1, dilation=1, - deformable_groups=deformable_groups).cuda() + conv_offset = nn.Conv2d( + inC, + deformable_groups * 2 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True, + ).cuda() + + conv_mask = nn.Conv2d( + inC, + deformable_groups * 1 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True, + ).cuda() + + dcn_v2 = DCNv2(inC, outC, (kH, kW), stride=1, padding=1, dilation=1, deformable_groups=deformable_groups).cuda() conv_offset.weight.data.zero_() conv_offset.bias.data.zero_() @@ -60,12 +60,13 @@ def check_zero_offset(): output *= 2 d = (input - output).abs().max() if d < 1e-10: - print('Zero offset passed') + print("Zero offset passed") else: - print('Zero offset failed') + print("Zero offset failed") print(input) print(output) + def check_gradient_dconv(): input = torch.rand(N, inC, inH, inW).cuda() * 0.01 @@ -91,43 +92,57 @@ def check_gradient_dconv(): padding = 1 dilation = 1 - print('check_gradient_dconv: ', - gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias, - stride, padding, dilation, deformable_groups), - eps=1e-3, atol=1e-4, rtol=1e-2)) + print( + "check_gradient_dconv: ", + gradcheck( + dcn_v2_conv, + (input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups), + eps=1e-3, + atol=1e-4, + rtol=1e-2, + ), + ) def check_pooling_zero_offset(): input = torch.randn(2, 16, 64, 64).cuda().zero_() - input[0, :, 16:26, 16:26] = 1. - input[1, :, 10:20, 20:30] = 2. 
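The check_pooling_zero_offset hunk here is Black-style reformatting only (1. becomes 1.0, calls are wrapped); the invariant the test encodes is that zero-offset deformable pooling over a constant patch returns that constant. A rough stand-in for the same invariant, using torchvision's roi_align as an illustrative substitute rather than the extension's kernel:

import torch
from torchvision.ops import roi_align

# Constant patch made strictly larger than the ROI so every bilinear neighbour lies inside it.
inp = torch.zeros(1, 16, 64, 64)
inp[0, :, 12:30, 12:30] = 1.0
# One ROI per row in [batch_index, x1, y1, x2, y2] form, the same layout the tests build.
boxes = torch.tensor([[0.0, 65.0, 65.0, 103.0, 103.0]])
out = roi_align(inp, boxes, output_size=7, spatial_scale=1.0 / 4)
assert torch.allclose(out, torch.ones_like(out))  # pooling a constant region returns the constant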
- rois = torch.tensor([ - [0, 65, 65, 103, 103], - [1, 81, 41, 119, 79], - ]).cuda().float() - pooling = DCNv2Pooling(spatial_scale=1.0 / 4, - pooled_size=7, - output_dim=16, - no_trans=True, - group_size=1, - trans_std=0.0).cuda() + input[0, :, 16:26, 16:26] = 1.0 + input[1, :, 10:20, 20:30] = 2.0 + rois = ( + torch.tensor( + [ + [0, 65, 65, 103, 103], + [1, 81, 41, 119, 79], + ] + ) + .cuda() + .float() + ) + pooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=16, + no_trans=True, + group_size=1, + trans_std=0.0, + ).cuda() out = pooling(input, rois, input.new()) - s = ', '.join(['%f' % out[i, :, :, :].mean().item() - for i in range(rois.shape[0])]) + s = ", ".join(["%f" % out[i, :, :, :].mean().item() for i in range(rois.shape[0])]) print(s) - dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, - pooled_size=7, - output_dim=16, - no_trans=False, - group_size=1, - trans_std=0.0).cuda() + dpooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=16, + no_trans=False, + group_size=1, + trans_std=0.0, + ).cuda() offset = torch.randn(20, 2, 7, 7).cuda().zero_() dout = dpooling(input, rois, offset) - s = ', '.join(['%f' % dout[i, :, :, :].mean().item() - for i in range(rois.shape[0])]) + s = ", ".join(["%f" % dout[i, :, :, :].mean().item() for i in range(rois.shape[0])]) print(s) @@ -153,24 +168,32 @@ def check_gradient_dpooling(): sample_per_part = 4 part_size = pooled_size - print('check_gradient_dpooling:', - gradcheck(dcn_v2_pooling, (input, rois, offset, - spatial_scale, - pooled_size, - output_dim, - no_trans, - group_size, - part_size, - sample_per_part, - trans_std), - eps=1e-4)) + print( + "check_gradient_dpooling:", + gradcheck( + dcn_v2_pooling, + ( + input, + rois, + offset, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size, + part_size, + sample_per_part, + trans_std, + ), + eps=1e-4, + ), + ) def example_dconv(): input = torch.randn(2, 64, 128, 128).cuda() # wrap all things (offset and mask) in DCN - dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, - padding=1, deformable_groups=2).cuda() + dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2).cuda() # print(dcn.weight.shape, input.shape) output = dcn(input) targert = output.new(*output.size()) @@ -193,20 +216,24 @@ def example_dpooling(): offset.requires_grad = True # normal roi_align - pooling = DCNv2Pooling(spatial_scale=1.0 / 4, - pooled_size=7, - output_dim=32, - no_trans=True, - group_size=1, - trans_std=0.1).cuda() + pooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=True, + group_size=1, + trans_std=0.1, + ).cuda() # deformable pooling - dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, - pooled_size=7, - output_dim=32, - no_trans=False, - group_size=1, - trans_std=0.1).cuda() + dpooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1, + ).cuda() out = pooling(input, rois, offset) dout = dpooling(input, rois, offset) @@ -234,13 +261,15 @@ def example_mdpooling(): rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) # mdformable pooling (V2) - dpooling = DCNPooling(spatial_scale=1.0 / 4, - pooled_size=7, - output_dim=32, - no_trans=False, - group_size=1, - trans_std=0.1, - deform_fc_dim=1024).cuda() + dpooling = DCNPooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1, + deform_fc_dim=1024, + ).cuda() dout = dpooling(input, rois) target = 
dout.new(*dout.size()) @@ -250,7 +279,7 @@ def example_mdpooling(): print(dout.shape) -if __name__ == '__main__': +if __name__ == "__main__": example_dconv() example_dpooling() diff --git a/test/testcpu.py b/test/testcpu.py new file mode 100644 index 0000000..c278107 --- /dev/null +++ b/test/testcpu.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python +from __future__ import absolute_import, division, print_function + +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from dcn_v2 import DCN, DCNPooling, DCNv2, DCNv2Pooling, dcn_v2_conv, dcn_v2_pooling + +deformable_groups = 1 +N, inC, inH, inW = 2, 2, 4, 4 +outC = 2 +kH, kW = 3, 3 + + +def conv_identify(weight, bias): + weight.data.zero_() + bias.data.zero_() + o, i, h, w = weight.shape + y = h // 2 + x = w // 2 + for p in range(i): + for q in range(o): + if p == q: + weight.data[q, p, y, x] = 1.0 + + +def check_zero_offset(): + conv_offset = nn.Conv2d( + inC, + deformable_groups * 2 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True, + ) + + conv_mask = nn.Conv2d( + inC, + deformable_groups * 1 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True, + ) + + dcn_v2 = DCNv2(inC, outC, (kH, kW), stride=1, padding=1, dilation=1, deformable_groups=deformable_groups) + + conv_offset.weight.data.zero_() + conv_offset.bias.data.zero_() + conv_mask.weight.data.zero_() + conv_mask.bias.data.zero_() + conv_identify(dcn_v2.weight, dcn_v2.bias) + + input = torch.randn(N, inC, inH, inW) + offset = conv_offset(input) + mask = conv_mask(input) + mask = torch.sigmoid(mask) + output = dcn_v2(input, offset, mask) + output *= 2 + d = (input - output).abs().max() + if d < 1e-10: + print("Zero offset passed") + else: + print("Zero offset failed") + print(input) + print(output) + + +def check_gradient_dconv(): + + input = torch.rand(N, inC, inH, inW) * 0.01 + input.requires_grad = True + + offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW) * 2 + # offset.data.zero_() + # offset.data -= 0.5 + offset.requires_grad = True + + mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW) + # mask.data.zero_() + mask.requires_grad = True + mask = torch.sigmoid(mask) + + weight = torch.randn(outC, inC, kH, kW) + weight.requires_grad = True + + bias = torch.rand(outC) + bias.requires_grad = True + + stride = 1 + padding = 1 + dilation = 1 + + print( + "check_gradient_dconv: ", + gradcheck( + dcn_v2_conv, + (input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups), + eps=1e-3, + atol=1e-4, + rtol=1e-2, + ), + ) + + +def check_pooling_zero_offset(): + + input = torch.randn(2, 16, 64, 64).zero_() + input[0, :, 16:26, 16:26] = 1.0 + input[1, :, 10:20, 20:30] = 2.0 + rois = torch.tensor( + [ + [0, 65, 65, 103, 103], + [1, 81, 41, 119, 79], + ] + ).float() + pooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=16, + no_trans=True, + group_size=1, + trans_std=0.0, + ) + + out = pooling(input, rois, input.new()) + s = ", ".join(["%f" % out[i, :, :, :].mean().item() for i in range(rois.shape[0])]) + print(s) + + dpooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=16, + no_trans=False, + group_size=1, + trans_std=0.0, + ) + offset = torch.randn(20, 2, 7, 7).zero_() + dout = dpooling(input, rois, offset) + s = ", ".join(["%f" % dout[i, :, :, :].mean().item() for i in range(rois.shape[0])]) + print(s) + + +def check_gradient_dpooling(): + input = torch.randn(2, 3, 5, 5) * 0.01 + N = 4 + batch_inds = 
torch.randint(2, (N, 1)).float() + x = torch.rand((N, 1)).float() * 15 + y = torch.rand((N, 1)).float() * 15 + w = torch.rand((N, 1)).float() * 10 + h = torch.rand((N, 1)).float() * 10 + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + offset = torch.randn(N, 2, 3, 3) + input.requires_grad = True + offset.requires_grad = True + + spatial_scale = 1.0 / 4 + pooled_size = 3 + output_dim = 3 + no_trans = 0 + group_size = 1 + trans_std = 0.0 + sample_per_part = 4 + part_size = pooled_size + + print( + "check_gradient_dpooling:", + gradcheck( + dcn_v2_pooling, + ( + input, + rois, + offset, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size, + part_size, + sample_per_part, + trans_std, + ), + eps=1e-4, + ), + ) + + +def example_dconv(): + input = torch.randn(2, 64, 128, 128) + # wrap all things (offset and mask) in DCN + dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2) + # print(dcn.weight.shape, input.shape) + output = dcn(input) + targert = output.new(*output.size()) + targert.data.uniform_(-0.01, 0.01) + error = (targert - output).mean() + error.backward() + print(output.shape) + + +def example_dpooling(): + input = torch.randn(2, 32, 64, 64) + batch_inds = torch.randint(2, (20, 1)).float() + x = torch.randint(256, (20, 1)).float() + y = torch.randint(256, (20, 1)).float() + w = torch.randint(64, (20, 1)).float() + h = torch.randint(64, (20, 1)).float() + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + offset = torch.randn(20, 2, 7, 7) + input.requires_grad = True + offset.requires_grad = True + + # normal roi_align + pooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=True, + group_size=1, + trans_std=0.1, + ) + + # deformable pooling + dpooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1, + ) + + out = pooling(input, rois, offset) + dout = dpooling(input, rois, offset) + print(out.shape) + print(dout.shape) + + target_out = out.new(*out.size()) + target_out.data.uniform_(-0.01, 0.01) + target_dout = dout.new(*dout.size()) + target_dout.data.uniform_(-0.01, 0.01) + e = (target_out - out).mean() + e.backward() + e = (target_dout - dout).mean() + e.backward() + + +def example_mdpooling(): + input = torch.randn(2, 32, 64, 64) + input.requires_grad = True + batch_inds = torch.randint(2, (20, 1)).float() + x = torch.randint(256, (20, 1)).float() + y = torch.randint(256, (20, 1)).float() + w = torch.randint(64, (20, 1)).float() + h = torch.randint(64, (20, 1)).float() + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + + # mdformable pooling (V2) + dpooling = DCNPooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1, + deform_fc_dim=1024, + ) + + dout = dpooling(input, rois) + target = dout.new(*dout.size()) + target.data.uniform_(-0.1, 0.1) + error = (target - dout).mean() + error.backward() + print(dout.shape) + + +if __name__ == "__main__": + + example_dconv() + example_dpooling() + example_mdpooling() + + check_pooling_zero_offset() + # zero offset check + if inC == outC: + check_zero_offset() + + check_gradient_dpooling() + check_gradient_dconv() + # """ + # ****** Note: backward is not reentrant error may not be a serious problem, + # ****** since the max error is less than 1e-7, + # ****** Still looking for what trigger this problem + # """ diff --git a/test/testcuda.py b/test/testcuda.py new file mode 100644 index 
0000000..b83a4aa --- /dev/null +++ b/test/testcuda.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python +from __future__ import absolute_import, division, print_function + +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from dcn_v2 import DCN, DCNPooling, DCNv2, DCNv2Pooling, dcn_v2_conv, dcn_v2_pooling + +deformable_groups = 1 +N, inC, inH, inW = 2, 2, 4, 4 +outC = 2 +kH, kW = 3, 3 + + +def conv_identify(weight, bias): + weight.data.zero_() + bias.data.zero_() + o, i, h, w = weight.shape + y = h // 2 + x = w // 2 + for p in range(i): + for q in range(o): + if p == q: + weight.data[q, p, y, x] = 1.0 + + +def check_zero_offset(): + conv_offset = nn.Conv2d( + inC, + deformable_groups * 2 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True, + ).cuda() + + conv_mask = nn.Conv2d( + inC, + deformable_groups * 1 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True, + ).cuda() + + dcn_v2 = DCNv2(inC, outC, (kH, kW), stride=1, padding=1, dilation=1, deformable_groups=deformable_groups).cuda() + + conv_offset.weight.data.zero_() + conv_offset.bias.data.zero_() + conv_mask.weight.data.zero_() + conv_mask.bias.data.zero_() + conv_identify(dcn_v2.weight, dcn_v2.bias) + + input = torch.randn(N, inC, inH, inW).cuda() + offset = conv_offset(input) + mask = conv_mask(input) + mask = torch.sigmoid(mask) + output = dcn_v2(input, offset, mask) + output *= 2 + d = (input - output).abs().max() + if d < 1e-10: + print("Zero offset passed") + else: + print("Zero offset failed") + print(input) + print(output) + + +def check_gradient_dconv(): + + input = torch.rand(N, inC, inH, inW).cuda() * 0.01 + input.requires_grad = True + + offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2 + # offset.data.zero_() + # offset.data -= 0.5 + offset.requires_grad = True + + mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda() + # mask.data.zero_() + mask.requires_grad = True + mask = torch.sigmoid(mask) + + weight = torch.randn(outC, inC, kH, kW).cuda() + weight.requires_grad = True + + bias = torch.rand(outC).cuda() + bias.requires_grad = True + + stride = 1 + padding = 1 + dilation = 1 + + print( + "check_gradient_dconv: ", + gradcheck( + dcn_v2_conv, + (input, offset, mask, weight, bias, stride, padding, dilation, deformable_groups), + eps=1e-3, + atol=1e-4, + rtol=1e-2, + ), + ) + + +def check_pooling_zero_offset(): + + input = torch.randn(2, 16, 64, 64).cuda().zero_() + input[0, :, 16:26, 16:26] = 1.0 + input[1, :, 10:20, 20:30] = 2.0 + rois = ( + torch.tensor( + [ + [0, 65, 65, 103, 103], + [1, 81, 41, 119, 79], + ] + ) + .cuda() + .float() + ) + pooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=16, + no_trans=True, + group_size=1, + trans_std=0.0, + ).cuda() + + out = pooling(input, rois, input.new()) + s = ", ".join(["%f" % out[i, :, :, :].mean().item() for i in range(rois.shape[0])]) + print(s) + + dpooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=16, + no_trans=False, + group_size=1, + trans_std=0.0, + ).cuda() + offset = torch.randn(20, 2, 7, 7).cuda().zero_() + dout = dpooling(input, rois, offset) + s = ", ".join(["%f" % dout[i, :, :, :].mean().item() for i in range(rois.shape[0])]) + print(s) + + +def check_gradient_dpooling(): + input = torch.randn(2, 3, 5, 5).cuda().float() * 0.01 + N = 4 + batch_inds = torch.randint(2, (N, 1)).cuda().float() + x = torch.rand((N, 1)).cuda().float() * 15 + y = torch.rand((N, 1)).cuda().float() * 15 + w 
= torch.rand((N, 1)).cuda().float() * 10 + h = torch.rand((N, 1)).cuda().float() * 10 + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + offset = torch.randn(N, 2, 3, 3).cuda() + input.requires_grad = True + offset.requires_grad = True + + spatial_scale = 1.0 / 4 + pooled_size = 3 + output_dim = 3 + no_trans = 0 + group_size = 1 + trans_std = 0.0 + sample_per_part = 4 + part_size = pooled_size + + print( + "check_gradient_dpooling:", + gradcheck( + dcn_v2_pooling, + ( + input, + rois, + offset, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size, + part_size, + sample_per_part, + trans_std, + ), + eps=1e-4, + ), + ) + + +def example_dconv(): + input = torch.randn(2, 64, 128, 128).cuda() + # wrap all things (offset and mask) in DCN + dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2).cuda() + # print(dcn.weight.shape, input.shape) + output = dcn(input) + targert = output.new(*output.size()) + targert.data.uniform_(-0.01, 0.01) + error = (targert - output).mean() + error.backward() + print(output.shape) + + +def example_dpooling(): + input = torch.randn(2, 32, 64, 64).cuda() + batch_inds = torch.randint(2, (20, 1)).cuda().float() + x = torch.randint(256, (20, 1)).cuda().float() + y = torch.randint(256, (20, 1)).cuda().float() + w = torch.randint(64, (20, 1)).cuda().float() + h = torch.randint(64, (20, 1)).cuda().float() + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + offset = torch.randn(20, 2, 7, 7).cuda() + input.requires_grad = True + offset.requires_grad = True + + # normal roi_align + pooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=True, + group_size=1, + trans_std=0.1, + ).cuda() + + # deformable pooling + dpooling = DCNv2Pooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1, + ).cuda() + + out = pooling(input, rois, offset) + dout = dpooling(input, rois, offset) + print(out.shape) + print(dout.shape) + + target_out = out.new(*out.size()) + target_out.data.uniform_(-0.01, 0.01) + target_dout = dout.new(*dout.size()) + target_dout.data.uniform_(-0.01, 0.01) + e = (target_out - out).mean() + e.backward() + e = (target_dout - dout).mean() + e.backward() + + +def example_mdpooling(): + input = torch.randn(2, 32, 64, 64).cuda() + input.requires_grad = True + batch_inds = torch.randint(2, (20, 1)).cuda().float() + x = torch.randint(256, (20, 1)).cuda().float() + y = torch.randint(256, (20, 1)).cuda().float() + w = torch.randint(64, (20, 1)).cuda().float() + h = torch.randint(64, (20, 1)).cuda().float() + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + + # mdformable pooling (V2) + dpooling = DCNPooling( + spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1, + deform_fc_dim=1024, + ).cuda() + + dout = dpooling(input, rois) + target = dout.new(*dout.size()) + target.data.uniform_(-0.1, 0.1) + error = (target - dout).mean() + error.backward() + print(dout.shape) + + +if __name__ == "__main__": + + example_dconv() + example_dpooling() + example_mdpooling() + + check_pooling_zero_offset() + # zero offset check + if inC == outC: + check_zero_offset() + + check_gradient_dpooling() + check_gradient_dconv() + # """ + # ****** Note: backward is not reentrant error may not be a serious problem, + # ****** since the max error is less than 1e-7, + # ****** Still looking for what trigger this problem + # """
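With the CPU ops added by this series, test/testcpu.py and test/testcuda.py can be complemented by a direct parity check between the two backends. A minimal sketch, assuming the extension is built with both CPU and CUDA support and importable as dcn_v2:

import torch
from dcn_v2 import DCN

torch.manual_seed(0)
dcn_cpu = DCN(8, 8, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2)
dcn_gpu = DCN(8, 8, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2).cuda()
dcn_gpu.load_state_dict(dcn_cpu.state_dict())  # same weights on both devices

x = torch.randn(2, 8, 32, 32)
diff = (dcn_cpu(x) - dcn_gpu(x.cuda()).cpu()).abs().max()
print(diff)  # expect roughly 1e-5 or smaller for float32

Any larger discrepancy would most likely point at the CPU im2col/col2im ports rather than the test harness.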