diff --git a/ntt/kernels.cu b/ntt/kernels.cu
index 50ed0a7..0b36033 100644
--- a/ntt/kernels.cu
+++ b/ntt/kernels.cu
@@ -23,6 +23,7 @@ T bit_rev(T i, unsigned int nbits)
 
 // Permutes the data in an array such that data[i] = data[bit_reverse(i)]
 // and data[bit_reverse(i)] = data[i]
+template<class fr_t>
 __launch_bounds__(1024) __global__
 void bit_rev_permutation(fr_t* d_out, const fr_t *d_in, uint32_t lg_domain_size)
 {
@@ -54,13 +55,15 @@ template<class T>
 static __device__ __host__ constexpr uint32_t lg2(T n)
 { uint32_t ret=0; while (n>>=1) ret++; return ret; }
 
-template<unsigned int Z_COUNT>
+template<class fr_t, unsigned int Z_COUNT>
 __launch_bounds__(192, 2) __global__
 void bit_rev_permutation_z(fr_t* out, const fr_t* in, uint32_t lg_domain_size)
 {
+    static_assert((Z_COUNT & (Z_COUNT-1)) == 0, "invalid Z_COUNT");
     const uint32_t LG_Z_COUNT = lg2(Z_COUNT);
 
-    extern __shared__ fr_t xchg[][Z_COUNT][Z_COUNT];
+    extern __shared__ int xchg_bit_rev[];
+    fr_t (*xchg)[Z_COUNT][Z_COUNT] = reinterpret_cast<decltype(xchg)>(xchg_bit_rev);
 
     uint32_t gid = threadIdx.x / Z_COUNT;
     uint32_t idx = threadIdx.x % Z_COUNT;
@@ -126,6 +129,7 @@ void bit_rev_permutation_z(fr_t* out, const fr_t* in, uint32_t lg_domain_size)
     // without "Z_COUNT <= WARP_SZ" compiler spills 128 bytes to stack :-(
 }
 
+template<class fr_t>
 __device__ __forceinline__
 fr_t get_intermediate_root(index_t pow, const fr_t (*roots)[WINDOW_SIZE])
 {
@@ -163,6 +167,7 @@ fr_t get_intermediate_root(index_t pow, const fr_t (*roots)[WINDOW_SIZE])
     return root;
 }
 
+template<class fr_t>
 __launch_bounds__(1024) __global__
 void LDE_distribute_powers(fr_t* d_inout, uint32_t lg_domain_size,
                            uint32_t lg_blowup, bool bitrev,
@@ -186,6 +191,7 @@ void LDE_distribute_powers(fr_t* d_inout, uint32_t lg_domain_size,
     }
 }
 
+template<class fr_t>
 __launch_bounds__(1024) __global__
 void LDE_spread_distribute_powers(fr_t* out, fr_t* in,
                                   const fr_t (*gen_powers)[WINDOW_SIZE],
@@ -193,7 +199,8 @@ void LDE_spread_distribute_powers(fr_t* out, fr_t* in,
                                   bool perform_shift = true,
                                   bool ext_pow = false)
 {
-    extern __shared__ fr_t exchange[]; // block size
+    extern __shared__ int xchg_lde_spread[]; // block size
+    fr_t* exchange = reinterpret_cast<fr_t*>(xchg_lde_spread);
 
     size_t domain_size = (size_t)1 << lg_domain_size;
     uint32_t blowup = 1u << lg_blowup;
@@ -268,6 +275,7 @@ void LDE_spread_distribute_powers(fr_t* out, fr_t* in,
     }
 }
 
+template<class fr_t>
 __device__ __forceinline__
 void get_intermediate_roots(fr_t& root0, fr_t& root1,
                             index_t idx0, index_t idx1,
@@ -290,7 +298,7 @@ void get_intermediate_roots(fr_t& root0, fr_t& root1,
     }
 }
 
-template<int z_count>
+template<class fr_t, int z_count>
 __device__ __forceinline__
 void coalesced_load(fr_t r[z_count], const fr_t* inout, index_t idx,
                     const unsigned int stage)
@@ -304,12 +312,12 @@ void coalesced_load(fr_t r[z_count], const fr_t* inout, index_t idx,
         r[z] = inout[idx];
 }
 
-template<int z_count>
+template<class fr_t, int z_count>
 __device__ __forceinline__
 void transpose(fr_t r[z_count])
 {
-    extern __shared__ fr_t shared_exchange[];
-    fr_t (*xchg)[z_count] = reinterpret_cast<decltype(xchg)>(shared_exchange);
+    extern __shared__ int xchg_transpose[];
+    fr_t (*xchg)[z_count] = reinterpret_cast<decltype(xchg)>(xchg_transpose);
 
     const unsigned int x = threadIdx.x & (z_count - 1);
     const unsigned int y = threadIdx.x & ~(z_count - 1);
@@ -325,7 +333,7 @@ void transpose(fr_t r[z_count])
         r[z] = xchg[y + x][z];
 }
 
-template<int z_count>
+template<class fr_t, int z_count>
 __device__ __forceinline__
 void coalesced_store(fr_t* inout, index_t idx, const fr_t r[z_count],
                      const unsigned int stage)
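
The recurring rename from `extern __shared__ fr_t ...[]` to an `int[]` plus a `reinterpret_cast` is forced by the templatization itself: once `fr_t` is a template parameter, each instantiation would redeclare the same dynamic shared-memory symbol with a different element type, which nvcc rejects ("declaration is incompatible with previous ..."). This is the same limitation the `SharedMemory<T>` helper in the CUDA samples works around. Below is a minimal sketch reproducing the problem and the fix; `demo`, `fp1_t`, and `fp2_t` are hypothetical stand-ins for the project's field types, not part of the patch.

```cuda
#include <cstdint>

// Hypothetical field-element stand-ins; any two distinct trivially-copyable
// types reproduce the conflict.
struct fp1_t { uint32_t limb; };
struct fp2_t { uint64_t limbs[4]; };

// Broken shape (kept as a comment): with the kernel templated on fr_t,
// "extern __shared__ fr_t xchg[]" declares the same shared symbol with a
// different type per instantiation, and nvcc refuses to compile it.
//
//   template<class fr_t>
//   __global__ void demo(fr_t* out)
//   {
//       extern __shared__ fr_t xchg[];  // fails once both fp1_t and
//       ...                             // fp2_t instantiations exist
//   }

// The patch's workaround: declare the dynamic shared memory with a fixed
// element type and a kernel-unique name, then view it as fr_t. All extern
// __shared__ declarations share the same base address, which in practice is
// aligned well enough for trivially-copyable field types.
template<class fr_t>
__global__ void demo(fr_t* out)
{
    extern __shared__ int xchg_demo[];
    fr_t* xchg = reinterpret_cast<fr_t*>(xchg_demo);

    // Reverse the block's elements through shared memory, just to touch it.
    // Launch as: demo<fp1_t><<<1, n, n * sizeof(fp1_t)>>>(d_ptr);
    xchg[threadIdx.x] = out[threadIdx.x];
    __syncthreads();
    out[threadIdx.x] = xchg[blockDim.x - 1 - threadIdx.x];
}

// Both instantiations now coexist in one translation unit:
template __global__ void demo<fp1_t>(fp1_t*);
template __global__ void demo<fp2_t>(fp2_t*);
```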
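Because every kernel now carries a leading `class fr_t` parameter instead of relying on a file-scope typedef, call sites must instantiate explicitly and size the dynamic shared memory themselves. The following host-side sketch for `bit_rev_permutation_z` is hedged: `Z_COUNT`, the block size, and the grid computation are illustrative assumptions (the patch only changes the kernel signatures), while the shared-memory size follows directly from the kernel's `xchg[][Z_COUNT][Z_COUNT]` view, one Z_COUNT x Z_COUNT tile per group of Z_COUNT threads.

```cuda
#include <cuda_runtime.h>
#include <cstdint>
// Assumes the patched ntt/kernels.cu (or a header declaring its kernels)
// is visible in this translation unit.

template<class fr_t>
void launch_bit_rev_permutation_z(fr_t* out, const fr_t* in,
                                  uint32_t lg_domain_size, cudaStream_t stream)
{
    constexpr unsigned int Z_COUNT = 8; // illustrative; must be a power of
                                        // two per the new static_assert
    const unsigned int block = 192;     // matches __launch_bounds__(192, 2)

    // One Z_COUNT x Z_COUNT fr_t tile per group of Z_COUNT threads, as
    // implied by the kernel's xchg[][Z_COUNT][Z_COUNT] declaration:
    const size_t shared_sz = (block / Z_COUNT) * Z_COUNT * Z_COUNT * sizeof(fr_t);

    // Illustrative grid only; the project's real dispatch logic may differ.
    const size_t domain_size  = (size_t)1 << lg_domain_size;
    const unsigned int tiles  = (unsigned int)(domain_size / ((size_t)Z_COUNT * Z_COUNT));
    const unsigned int groups = block / Z_COUNT;
    const unsigned int grid   = (tiles + groups - 1) / groups;

    // fr_t is now an explicit template argument, so it must be spelled out:
    bit_rev_permutation_z<fr_t, Z_COUNT>
        <<<grid, block, shared_sz, stream>>>(out, in, lg_domain_size);
}
```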