Skip to content

Commit

Permalink
Add cuda streams to rulebook iterations
Browse files Browse the repository at this point in the history
  • Loading branch information
mpmisko committed Aug 23, 2019
1 parent 1171aae commit c701936
Show file tree
Hide file tree
Showing 7 changed files with 574 additions and 429 deletions.
106 changes: 51 additions & 55 deletions sparseconvnet/SCN/CUDA/AveragePooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,137 +9,133 @@
// NTX must be >=2 so r is filled properly
template <typename T, Int NTX, Int NTY>
__global__ void AveragePooling_fp(T *input_features, T *output_features,
Int nPlanes, Int input_stride,
Int output_stride, Int *rules, Int nHot,
T alpha) {
Int nPlanes, Int input_stride,
Int output_stride, Int *rules, Int nHot,
T alpha) {
__shared__ Int r[NTY * 2];
for (Int n = blockIdx.x * NTY; n < nHot; n += gridDim.x * NTY) {
{
Int i = threadIdx.x + NTX * threadIdx.y;
if (i < NTY * 2 and i < 2 * (nHot - n))
r[i] = rules[2 * n + i];
r[i] = rules[2 * n + i];
}
__syncthreads();
if (n + threadIdx.y < nHot) {
Int i = r[2 * threadIdx.y] * input_stride;
Int o = r[2 * threadIdx.y + 1] * output_stride;
for (Int plane = threadIdx.x; plane < nPlanes; plane += NTX)
output_features[o + plane]+= alpha * input_features[i + plane];
// atomicAdd(&output_features[o + plane],
// alpha * input_features[i + plane]);
output_features[o + plane] += alpha * input_features[i + plane];
// atomicAdd(&output_features[o + plane],
// alpha * input_features[i + plane]);
}
__syncthreads();
}
}

template <typename T>
void cuda_AveragePooling_ForwardPass(T *input_features, T *output_features,
Int nPlanes, Int input_stride,
Int output_stride, RuleBook _rules,
Int filterVolume) {
RULEBOOKITERATOR((AveragePooling_fp<T, 32, 32><<<32, dim3(32, 32)>>>(
input_features, output_features, nPlanes, input_stride, output_stride,
rbB, nHotB, 1.0 / filterVolume));
, )
Int nPlanes, Int input_stride,
Int output_stride, RuleBook _rules,
Int filterVolume) {
auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void {
AveragePooling_fp<T, 32, 32><<<32, dim3(32, 32), 0, stream>>>(
input_features, output_features, nPlanes, input_stride, output_stride,
rbB, nHotB, 1.0 / filterVolume);
};

iterateRuleBook(_rules, application);
}
template <typename T, Int NTX, Int NTY>
__global__ void AveragePooling_bp(T *d_input_features, T *d_output_features,
Int nPlanes, Int input_stride,
Int output_stride, Int *rules, Int nHot,
T alpha) {
Int nPlanes, Int input_stride,
Int output_stride, Int *rules, Int nHot,
T alpha) {
__shared__ Int r[NTY * 2];
for (Int n = blockIdx.x * NTY; n < nHot; n += gridDim.x * NTY) {
{
Int i = threadIdx.x + NTX * threadIdx.y;
if (i < NTY * 2 and i < 2 * (nHot - n))
r[i] = rules[2 * n + i];
r[i] = rules[2 * n + i];
}
__syncthreads();
if (n + threadIdx.y < nHot) {
Int i = r[2 * threadIdx.y] * input_stride;
Int o = r[2 * threadIdx.y + 1] * output_stride;
for (Int plane = threadIdx.x; plane < nPlanes; plane += NTX)
d_input_features[i + plane] += alpha * d_output_features[o + plane];
d_input_features[i + plane] += alpha * d_output_features[o + plane];
}
__syncthreads();
}
}

template <typename T>
void cuda_AveragePooling_BackwardPass(T *d_input_features, T *d_output_features,
Int nPlanes, Int input_stride,
Int output_stride, RuleBook _rules,
Int filterVolume) {
RULEBOOKITERATOR((AveragePooling_bp<T, 32, 32><<<32, dim3(32, 32)>>>(
d_input_features, d_output_features, nPlanes, input_stride, output_stride,
rbB, nHotB, 1.0 / filterVolume));
, )
}









Int nPlanes, Int input_stride,
Int output_stride, RuleBook _rules,
Int filterVolume) {

auto application = [&](Int *rbB, Int nHotB, cudaStream_t &stream) -> void {
AveragePooling_bp<T, 32, 32><<<32, dim3(32, 32), 0, stream>>>(
d_input_features, d_output_features, nPlanes, input_stride, output_stride,
rbB, nHotB, 1.0 / filterVolume);
};

iterateRuleBook(_rules, application);
}

// NTX must be >=2 so r is filled properly
template <typename T, Int NTX, Int NTY>
__global__ void CopyFeaturesHelper_fp(T *input_features, T *output_features, Int * rules,
Int nPlanes, Int nHot) {
__global__ void CopyFeaturesHelper_fp(T *input_features, T *output_features,
Int *rules, Int nPlanes, Int nHot) {
__shared__ Int r[NTY * 2];
for (Int n = blockIdx.x * NTY; n < nHot; n += gridDim.x * NTY) {
{
Int i = threadIdx.x + NTX * threadIdx.y;
if (i < NTY * 2 and i < 2 * (nHot - n))
r[i] = rules[2 * n + i];
r[i] = rules[2 * n + i];
}
__syncthreads();
if (n + threadIdx.y < nHot) {
Int i = r[2 * threadIdx.y+1] * nPlanes;
Int o = r[2 * threadIdx.y ] * nPlanes;
Int i = r[2 * threadIdx.y + 1] * nPlanes;
Int o = r[2 * threadIdx.y] * nPlanes;
for (Int plane = threadIdx.x; plane < nPlanes; plane += NTX)
output_features[o + plane]= input_features[i + plane];
output_features[o + plane] = input_features[i + plane];
}
__syncthreads();
}
}

template <typename T>
void cuda_CopyFeaturesHelper_ForwardPass(T *input_features, T *output_features, Int* rules,
Int nPlanes, Int nHot) {
CopyFeaturesHelper_fp<T, 32, 32><<<32, dim3(32, 32)>>>(
input_features, output_features, rules, nPlanes,
nHot);
void cuda_CopyFeaturesHelper_ForwardPass(T *input_features, T *output_features,
Int *rules, Int nPlanes, Int nHot) {
CopyFeaturesHelper_fp<T, 32, 32><<<32, dim3(32, 32)>>>(
input_features, output_features, rules, nPlanes, nHot);
}
template <typename T, Int NTX, Int NTY>
__global__ void CopyFeaturesHelper_bp(T *d_input_features, T *d_output_features, Int* rules,
Int nPlanes,Int nHot) {
__global__ void CopyFeaturesHelper_bp(T *d_input_features, T *d_output_features,
Int *rules, Int nPlanes, Int nHot) {
__shared__ Int r[NTY * 2];
for (Int n = blockIdx.x * NTY; n < nHot; n += gridDim.x * NTY) {
{
Int i = threadIdx.x + NTX * threadIdx.y;
if (i < NTY * 2 and i < 2 * (nHot - n))
r[i] = rules[2 * n + i];
r[i] = rules[2 * n + i];
}
__syncthreads();
if (n + threadIdx.y < nHot) {
Int i = r[2 * threadIdx.y+1] * nPlanes;
Int i = r[2 * threadIdx.y + 1] * nPlanes;
Int o = r[2 * threadIdx.y] * nPlanes;
for (Int plane = threadIdx.x; plane < nPlanes; plane += NTX)
d_input_features[i + plane] = d_output_features[o + plane];
d_input_features[i + plane] = d_output_features[o + plane];
}
__syncthreads();
}
}

template <typename T>
void cuda_CopyFeaturesHelper_BackwardPass(T *d_input_features, T *d_output_features,
Int* rules, Int nPlanes, Int nHot) {
CopyFeaturesHelper_bp<T, 32, 32><<<32, dim3(32, 32)>>>(
void cuda_CopyFeaturesHelper_BackwardPass(T *d_input_features,
T *d_output_features, Int *rules,
Int nPlanes, Int nHot) {
CopyFeaturesHelper_bp<T, 32, 32><<<32, dim3(32, 32)>>>(
d_input_features, d_output_features, rules, nPlanes, nHot);
}
Loading

0 comments on commit c701936

Please sign in to comment.