Skip to content

Commit

Permalink
ff/shfl.cuh: balance size vs. performance.
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Nov 27, 2024
1 parent b2ceb63 commit 83b5d2a
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions ff/shfl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ static T shfl_up(const T& src, unsigned int off)
}

template<class T> __device__ __forceinline__
static T& add_up(T& x_lane, unsigned limit = WARP_SZ)
static T& add_up(T& x_lane, unsigned limit)
{
const unsigned laneid = threadIdx.x % WARP_SZ;

Expand All @@ -76,7 +76,25 @@ static T& add_up(T& x_lane, unsigned limit = WARP_SZ)
}

// One-argument add_up: combines x_lane with values shuffled in from other
// lanes of the warp and returns a reference to x_lane, updated in place.
//
// When sizeof(T) exceeds 16 bytes the work is forwarded to the two-argument
// overload with limit = WARP_SZ. Otherwise the loop below runs under
// #pragma unroll with off = 1, 2, 4, ... while off < WARP_SZ.
//
// NOTE(review): this looks like an inclusive prefix sum across the warp,
// assuming shfl_up(v, off) yields the value held `off` lanes below and
// T::csel(a, b, cond) picks `a` when cond is true - confirm against their
// definitions. The 16-byte threshold is presumably the "size vs.
// performance" trade-off named in the commit title - confirm.
template<class T> __device__ __forceinline__
static T& mul_up(T& x_lane, unsigned limit = WARP_SZ)
static T& add_up(T& x_lane)
{
// Larger element types take the two-argument overload instead of the loop below.
if (sizeof(T) > 16)
return add_up(x_lane, WARP_SZ);

// Index of this thread within its warp, derived from threadIdx.x only.
const unsigned laneid = threadIdx.x % WARP_SZ;

#pragma unroll
for (unsigned off = 1; off < WARP_SZ; off <<= 1) {
// Value obtained through shfl_up at distance `off`, then accumulated with our own.
auto temp = shfl_up(x_lane, off);
temp += x_lane;
// Choice between the old value and the accumulated one is made by T::csel,
// keyed on whether this lane's index is below `off`.
x_lane = T::csel(x_lane, temp, laneid < off);
}

return x_lane;
}

template<class T> __device__ __forceinline__
static T& mul_up(T& x_lane, unsigned limit)
{
const unsigned laneid = threadIdx.x % WARP_SZ;

Expand All @@ -91,4 +109,22 @@ static T& mul_up(T& x_lane, unsigned limit = WARP_SZ)

return x_lane;
}

// One-argument mul_up: folds x_lane together with values shuffled in from
// other lanes of the warp using operator*=, updating x_lane in place and
// returning a reference to it.
//
// Types of at most 4 bytes run the loop below under #pragma unroll; anything
// wider is delegated to the two-argument overload with limit = WARP_SZ.
//
// NOTE(review): this reads as an inclusive prefix product over the warp,
// assuming shfl_up(v, d) yields the value held `d` lanes below and
// T::csel(a, b, cond) picks `a` when cond is true - confirm against their
// definitions.
template<class T> __device__ __forceinline__
static T& mul_up(T& x_lane)
{
    if (sizeof(T) <= 4) {
        // Position of this thread inside its warp.
        const unsigned lane = threadIdx.x % WARP_SZ;

        #pragma unroll
        for (unsigned stride = 1; stride < WARP_SZ; stride *= 2) {
            // Fetch the partner value at distance `stride` and multiply ours in.
            auto product = shfl_up(x_lane, stride);
            product *= x_lane;
            // T::csel decides between the previous value and the new product,
            // keyed on whether this lane's index is below `stride`.
            x_lane = T::csel(x_lane, product, lane < stride);
        }
    } else {
        // Wider types: let the limit-taking overload do the work on x_lane.
        mul_up(x_lane, WARP_SZ);
    }

    return x_lane;
}
#endif

0 comments on commit 83b5d2a

Please sign in to comment.