Skip to content

Commit

Permalink
ntt/kernels.cu: template-ize for better versatility.
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Dec 13, 2024
1 parent 5cc014d commit f208c21
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions ntt/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ T bit_rev(T i, unsigned int nbits)

// Permutes the data in an array such that data[i] = data[bit_reverse(i)]
// and data[bit_reverse(i)] = data[i]
template<class fr_t>
__launch_bounds__(1024) __global__
void bit_rev_permutation(fr_t* d_out, const fr_t *d_in, uint32_t lg_domain_size)
{
Expand Down Expand Up @@ -54,13 +55,15 @@ template<typename T>
static __device__ __host__ constexpr uint32_t lg2(T n)
{ uint32_t ret=0; while (n>>=1) ret++; return ret; }

template<unsigned int Z_COUNT>
template<unsigned int Z_COUNT, class fr_t>
__launch_bounds__(192, 2) __global__
void bit_rev_permutation_z(fr_t* out, const fr_t* in, uint32_t lg_domain_size)
{
static_assert((Z_COUNT & (Z_COUNT-1)) == 0, "invalid Z_COUNT");
const uint32_t LG_Z_COUNT = lg2(Z_COUNT);

extern __shared__ fr_t xchg[][Z_COUNT][Z_COUNT];
extern __shared__ int xchg_bit_rev[];
fr_t (*xchg)[Z_COUNT][Z_COUNT] = reinterpret_cast<decltype(xchg)>(xchg_bit_rev);

uint32_t gid = threadIdx.x / Z_COUNT;
uint32_t idx = threadIdx.x % Z_COUNT;
Expand Down Expand Up @@ -126,6 +129,7 @@ void bit_rev_permutation_z(fr_t* out, const fr_t* in, uint32_t lg_domain_size)
// without "Z_COUNT <= WARP_SZ" compiler spills 128 bytes to stack :-(
}

template<class fr_t>
__device__ __forceinline__
fr_t get_intermediate_root(index_t pow, const fr_t (*roots)[WINDOW_SIZE])
{
Expand Down Expand Up @@ -163,6 +167,7 @@ fr_t get_intermediate_root(index_t pow, const fr_t (*roots)[WINDOW_SIZE])
return root;
}

template<class fr_t>
__launch_bounds__(1024) __global__
void LDE_distribute_powers(fr_t* d_inout, uint32_t lg_domain_size,
uint32_t lg_blowup, bool bitrev,
Expand All @@ -186,14 +191,16 @@ void LDE_distribute_powers(fr_t* d_inout, uint32_t lg_domain_size,
}
}

template<class fr_t>
__launch_bounds__(1024) __global__
void LDE_spread_distribute_powers(fr_t* out, fr_t* in,
const fr_t (*gen_powers)[WINDOW_SIZE],
uint32_t lg_domain_size, uint32_t lg_blowup,
bool perform_shift = true,
bool ext_pow = false)
{
extern __shared__ fr_t exchange[]; // block size
extern __shared__ int xchg_lde_spread[]; // block size
fr_t* exchange = reinterpret_cast<decltype(exchange)>(xchg_lde_spread);

size_t domain_size = (size_t)1 << lg_domain_size;
uint32_t blowup = 1u << lg_blowup;
Expand Down Expand Up @@ -268,6 +275,7 @@ void LDE_spread_distribute_powers(fr_t* out, fr_t* in,
}
}

template<class fr_t>
__device__ __forceinline__
void get_intermediate_roots(fr_t& root0, fr_t& root1,
index_t idx0, index_t idx1,
Expand All @@ -290,7 +298,7 @@ void get_intermediate_roots(fr_t& root0, fr_t& root1,
}
}

template<int z_count>
template<int z_count, class fr_t>
__device__ __forceinline__
void coalesced_load(fr_t r[z_count], const fr_t* inout, index_t idx,
const unsigned int stage)
Expand All @@ -304,12 +312,12 @@ void coalesced_load(fr_t r[z_count], const fr_t* inout, index_t idx,
r[z] = inout[idx];
}

template<int z_count>
template<int z_count, class fr_t>
__device__ __forceinline__
void transpose(fr_t r[z_count])
{
extern __shared__ fr_t shared_exchange[];
fr_t (*xchg)[z_count] = reinterpret_cast<decltype(xchg)>(shared_exchange);
extern __shared__ int xchg_transpose[];
fr_t (*xchg)[z_count] = reinterpret_cast<decltype(xchg)>(xchg_transpose);

const unsigned int x = threadIdx.x & (z_count - 1);
const unsigned int y = threadIdx.x & ~(z_count - 1);
Expand All @@ -325,7 +333,7 @@ void transpose(fr_t r[z_count])
r[z] = xchg[y + x][z];
}

template<int z_count>
template<int z_count, class fr_t>
__device__ __forceinline__
void coalesced_store(fr_t* inout, index_t idx, const fr_t r[z_count],
const unsigned int stage)
Expand Down

0 comments on commit f208c21

Please sign in to comment.