From c7400b600d0a0f4905fe69b8bd1bd38b0cca475b Mon Sep 17 00:00:00 2001
From: lishicheng1996
Date: Mon, 5 Aug 2024 15:28:19 +0800
Subject: [PATCH] decoder MMHA kernel: support INT8 SCALE_Q_INSTEAD_OF_K and
 SCALE_P_INSTEAD_OF_V

---
 .../decoderMaskedMultiheadAttentionTemplate.h | 51 ++++++++++++++-----
 .../decoderMaskedMultiheadAttentionUtils.h    |  9 ++++
 2 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
index 6671b46fe..8107cdfe7 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h
@@ -54,6 +54,11 @@ namespace kernels
 #define MMHA_FP8_SCALE_P_INSTEAD_OF_V
 #endif // !defined ENABLE_FP8
 
+// Apply the INT8 scaling to Q instead of K.
+#define MMHA_INT8_SCALE_Q_INSTEAD_OF_K
+// Apply the INT8 scaling to P instead of V.
+#define MMHA_INT8_SCALE_P_INSTEAD_OF_V
+
 // Below are knobs to extend FP32 accumulation for higher FP16 accuracy
 
 // Does not seem to affect the accuracy that much
@@ -959,8 +964,12 @@ inline __device__ void Logit_value_fma(
     float logit = is_mask ? 0.f : reinterpret_cast<float const*>(logits_smem)[0];
     if constexpr (INT8_KV_CACHE)
     {
+#ifdef MMHA_INT8_SCALE_P_INSTEAD_OF_V
+        out = fma(logit, v_vec, out);
+#else
         V_vec_accum v_vec_ = mul<V_vec_accum, float, V_vec_m>(v_scale, v_vec);
         out = fma(logit, cast_to_float(v_vec_), out);
+#endif // MMHA_INT8_SCALE_P_INSTEAD_OF_V
     }
     else if constexpr (FP8_KV_CACHE)
     {
@@ -979,8 +988,12 @@ inline __device__ void Logit_value_fma(
     Tk logit = is_mask ? Tk(0.f) : logits_smem[0];
     if constexpr (INT8_KV_CACHE)
     {
+#ifdef MMHA_INT8_SCALE_P_INSTEAD_OF_V
+        out = fma(logit, v_vec, out);
+#else
         V_vec_accum v_vec_ = mul<V_vec_accum, float, V_vec_m>(v_scale, v_vec);
         out = fma(logit, v_vec_, out);
+#endif // MMHA_INT8_SCALE_P_INSTEAD_OF_V
     }
     else if constexpr (FP8_KV_CACHE)
     {
@@ -1312,9 +1325,18 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     static constexpr bool ENABLE_8BITS_K_CACHE = sizeof(TKcache) == 1;
     static constexpr bool ENABLE_8BITS_KV_CACHE = sizeof(Tcache) == 1;
     // FP8 KV Cache.
+#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
     static constexpr bool FP8_K_CACHE = std::is_same<TKcache, __nv_fp8_e4m3>::value;
+#else
+    static constexpr bool FP8_K_CACHE = false;
+#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
     static constexpr bool FP8_KV_CACHE = std::is_same<Tcache, __nv_fp8_e4m3>::value;
     // INT8 KV Cache.
+#ifdef MMHA_INT8_SCALE_Q_INSTEAD_OF_K
+    static constexpr bool INT8_K_CACHE = std::is_same<TKcache, int8_t>::value;
+#else
+    static constexpr bool INT8_K_CACHE = false;
+#endif // MMHA_INT8_SCALE_Q_INSTEAD_OF_K
     static constexpr bool INT8_KV_CACHE = std::is_same<Tcache, int8_t>::value;
 
     // The size of a warp.
@@ -1734,8 +1756,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     {
 
         // Store the Q values to shared memory.
-#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
-        if constexpr (FP8_K_CACHE)
+        if constexpr (FP8_K_CACHE || INT8_K_CACHE)
         {
             // There are many more elements from K than elements from Q so we pre-scale Q instead
             // of scaling all the elements from K. It helps reduce the number of ops.
@@ -1743,12 +1764,11 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
             zero(scaled_q);
             if (is_valid_qk_vec)
             {
-                scaled_q = mul(k_scale_quant_orig, q);
+                scaled_q = mul(k_scale_quant_orig, q);
             }
             reinterpret_cast<Qk_vec_k*>(&q_smem[qk_vec_idx])[0] = scaled_q;
         }
         else
-#endif
         {
             // Set padded Dh to 0 for the correctness of QK (when Dh != Dh_Max).
             Qk_vec_k zero_q;
@@ -2012,13 +2032,11 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
         // Compute the dot product between Q and K.
        // Note that dot will convert 8bit vec to the accumulation data type (float by default).
        float qk_ = 0.f;
-#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
-        if constexpr (FP8_K_CACHE)
+        if constexpr (FP8_K_CACHE || INT8_K_CACHE)
        {
            qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
        }
        else
-#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
        {
            if constexpr (ENABLE_8BITS_K_CACHE)
            {
@@ -2158,13 +2176,11 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
        // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
        // Note that dot will convert 8bit vec to the accumulation data type (float by default).
        float qk_ = 0.f;
-#ifdef MMHA_FP8_SCALE_Q_INSTEAD_OF_K
-        if constexpr (FP8_K_CACHE)
+        if constexpr (FP8_K_CACHE || INT8_K_CACHE)
        {
            qk_ = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * params.inv_sqrt_dh;
        }
        else
-#endif // MMHA_FP8_SCALE_Q_INSTEAD_OF_K
        {
            if constexpr (ENABLE_8BITS_K_CACHE)
            {
@@ -2338,12 +2354,19 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     // Compute the sum.
     sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
 
-// Normalize the logits.
+    // Normalize the logits.
+    float logit_scale = 1.0f;
 #ifdef MMHA_FP8_SCALE_P_INSTEAD_OF_V
-    float logit_scale = (FP8_KV_CACHE ? kv_scale_quant_orig_f : 1.0f);
-#else
-    float logit_scale = 1.f;
+    if constexpr (FP8_KV_CACHE) {
+        logit_scale = kv_scale_quant_orig_f;
+    }
 #endif // MMHA_FP8_SCALE_P_INSTEAD_OF_V
+#ifdef MMHA_INT8_SCALE_P_INSTEAD_OF_V
+    if constexpr (INT8_KV_CACHE) {
+        logit_scale = kv_scale_quant_orig_f;
+    }
+#endif // MMHA_INT8_SCALE_P_INSTEAD_OF_V
+
     float inv_sum = __fdividef(logit_scale, sum + 1.e-6f);
 
     int const normlization_loop_end = MULTI_BLOCK_FLAG ? timesteps_per_block : kv_loop_length;
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h
index 59ad3dd50..de69fe8c9 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttentionUtils.h
@@ -2431,6 +2431,15 @@ inline __device__ float4 mul(float4 a, int32_t b)
     return fc;
 }
 
+///////////////////////////////////////////////////////////////////////////////////////////////
+
+template <>
+inline __device__ Float4_ mul(float4 a, int32_t b)
+{
+    float4 fc = mul(a, b);
+    return reinterpret_cast<Float4_&>(fc);
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 inline __device__ float sum(float v)
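
Note on why the two knobs are safe (illustrative sketch, not part of the patch). With an INT8 KV
cache, k = k_scale * k_int8 and v = v_scale * v_int8, so q . k = (k_scale * q) . k_int8, and the
attention output sum_t(softmax_t * v_t) equals sum_t((exp(l_t) * v_scale / sum) * v_int8_t). That
is the algebra the patch relies on: Q is pre-scaled once before it is written to q_smem (far fewer
elements than the whole K cache), and the V scale is folded into inv_sum through logit_scale, the
same way the FP8 path already does. The toy host-side C++ below uses hypothetical sizes, scales
and a made-up file name, a single head, and omits the 1/sqrt(Dh) factor and the max-subtraction;
it only checks that the folded path matches the fully dequantized reference.

// note_int8_scale_folding.cpp -- illustrative only, NOT part of the patch.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int D = 4;          // head size (toy value)
    constexpr int T = 3;          // cached timesteps (toy value)
    float const k_scale = 0.05f;  // INT8 -> float scale of the K cache
    float const v_scale = 0.08f;  // INT8 -> float scale of the V cache

    float const q[D] = {0.3f, -1.2f, 0.7f, 0.1f};
    int8_t const k_cache[T * D] = {12, -3, 45, 7, -20, 33, 5, -8, 9, 14, -27, 60};
    int8_t const v_cache[T * D] = {4, -11, 23, 8, 17, -5, 30, -2, -9, 25, 6, 12};

    // Reference path: dequantize every K and V element before using it.
    std::vector<float> ref_logits(T), ref_out(D, 0.f);
    float ref_sum = 0.f;
    for (int t = 0; t < T; ++t)
    {
        float dot = 0.f;
        for (int d = 0; d < D; ++d)
            dot += q[d] * (k_scale * k_cache[t * D + d]);
        ref_logits[t] = std::exp(dot);
        ref_sum += ref_logits[t];
    }
    for (int t = 0; t < T; ++t)
        for (int d = 0; d < D; ++d)
            ref_out[d] += (ref_logits[t] / ref_sum) * (v_scale * v_cache[t * D + d]);

    // Folded path: scale Q once by k_scale (SCALE_Q_INSTEAD_OF_K) and fold v_scale
    // into the softmax normalization factor (SCALE_P_INSTEAD_OF_V).
    float scaled_q[D];
    for (int d = 0; d < D; ++d)
        scaled_q[d] = k_scale * q[d];
    std::vector<float> logits(T), out(D, 0.f);
    float sum = 0.f;
    for (int t = 0; t < T; ++t)
    {
        float dot = 0.f;
        for (int d = 0; d < D; ++d)
            dot += scaled_q[d] * k_cache[t * D + d]; // K stays INT8, no per-element scale
        logits[t] = std::exp(dot);
        sum += logits[t];
    }
    float const inv_sum = v_scale / sum; // mirrors inv_sum = __fdividef(logit_scale, sum)
    for (int t = 0; t < T; ++t)
        for (int d = 0; d < D; ++d)
            out[d] += (logits[t] * inv_sum) * v_cache[t * D + d]; // V stays INT8

    for (int d = 0; d < D; ++d)
        std::printf("ref_out[%d] = %.6f   folded_out[%d] = %.6f\n", d, ref_out[d], d, out[d]);
    return 0;
}

Both printed columns agree up to float rounding, since the two paths differ only in where the
constant scales are applied.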