From 61f086df00ff9c3fce11e417745e2fc722b15b90 Mon Sep 17 00:00:00 2001
From: Liu Liu
Date: Fri, 15 Dec 2023 21:37:30 -0500
Subject: [PATCH] Add grouped query attention invocation from SDP.

---
 ...nnc_scaled_dot_product_attention_cpu_ref.c | 22 +++++---
 ...scaled_dot_product_attention_flash_attn.cu | 35 ++++++------
 test/int/nnc/cublas.tests.c                   | 56 ++++++++++---------
 3 files changed, 62 insertions(+), 51 deletions(-)

diff --git a/lib/nnc/cmd/scaled_dot_product_attention/ccv_nnc_scaled_dot_product_attention_cpu_ref.c b/lib/nnc/cmd/scaled_dot_product_attention/ccv_nnc_scaled_dot_product_attention_cpu_ref.c
index a5a74ac4e..fccb7d812 100644
--- a/lib/nnc/cmd/scaled_dot_product_attention/ccv_nnc_scaled_dot_product_attention_cpu_ref.c
+++ b/lib/nnc/cmd/scaled_dot_product_attention/ccv_nnc_scaled_dot_product_attention_cpu_ref.c
@@ -55,7 +55,10 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 		cdim[0] = cdim[1], cdim[1] = cdim[2], cdim[2] = 1;
 	}
 	assert(qdim[0] == kdim[0] && kdim[0] == vdim[0] && vdim[0] == cdim[0]);
-	assert(qdim[2] == kdim[2] && kdim[2] == vdim[2] && vdim[2] == cdim[2]);
+	assert(qdim[2] == cdim[2]);
+	assert(kdim[2] == vdim[2]);
+	assert(qdim[2] % kdim[2] == 0);
+	assert(qdim[2] >= kdim[2]);
 	assert(qdim[3] == kdim[3]);
 	assert(kdim[1] == vdim[1]);
 	assert(cdim[1] == qdim[1]);
@@ -117,8 +120,8 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 		for (i[1] = 0; i[1] < qdim[2]; i[1]++)
 		{
 			const float* const qp1 = qp0 + i[1] * qstride[2];
-			const float* const kp1 = kp0 + i[1] * kstride[2];
-			const float* const vp1 = vp0 + i[1] * vstride[2];
+			const float* const kp1 = kp0 + (i[1] % kdim[2]) * kstride[2];
+			const float* const vp1 = vp0 + (i[1] % vdim[2]) * vstride[2];
 			const float* const amp1 = amp && amdim[1] > 1 ? amp0 + i[1] * amstride[1] : amp0;
 			float* const cp1 = cp0 + i[1] * cstride[2];
 			float* const ssp1 = ssp0 ? ssp0 + i[1] * ssstride[1] : 0;
@@ -327,7 +330,10 @@ static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, c
 		dvdim[0] = dvdim[1], dvdim[1] = dvdim[2], dvdim[2] = 1;
 	}
 	assert(qdim[0] == kdim[0] && kdim[0] == vdim[0] && vdim[0] == gdim[0]);
-	assert(qdim[2] == kdim[2] && kdim[2] == vdim[2] && vdim[2] == gdim[2]);
+	assert(qdim[2] == gdim[2]);
+	assert(kdim[2] == vdim[2]);
+	assert(qdim[2] % kdim[2] == 0);
+	assert(qdim[2] >= kdim[2]);
 	assert(qdim[3] == kdim[3]);
 	assert(kdim[1] == vdim[1]);
 	assert(gdim[1] == qdim[1]);
@@ -379,12 +385,12 @@ static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, c
 		for (i[1] = 0; i[1] < qdim[2]; i[1]++)
 		{
 			const float* const qp1 = qp0 + i[1] * qstride[2];
-			const float* const kp1 = kp0 + i[1] * kstride[2];
-			const float* const vp1 = vp0 + i[1] * vstride[2];
+			const float* const kp1 = kp0 + (i[1] % kdim[2]) * kstride[2];
+			const float* const vp1 = vp0 + (i[1] % vdim[2]) * vstride[2];
			const float* const gp1 = gp0 + i[1] * gstride[2];
 			float* const dqp1 = dqp0 + i[1] * dqstride[2];
-			float* const dkp1 = dkp0 + i[1] * dkstride[2];
-			float* const dvp1 = dvp0 + i[1] * dvstride[2];
+			float* const dkp1 = dkp0 + (i[1] % dkdim[2]) * dkstride[2];
+			float* const dvp1 = dvp0 + (i[1] % dvdim[2]) * dvstride[2];
 			// Compute Q @ K^T
 			int x, y, k;
 			for (y = 0; y < kdim[1]; y++)
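
Note on the CPU reference change above: grouped query attention lets the query carry more heads than key/value (qdim[2] must be a multiple of kdim[2]), and each query head i[1] reads key/value head i[1] % kdim[2]. A minimal standalone sketch of that head mapping (the helper name and the Hq/Hk parameters are illustrative, not part of the patch):

    #include <assert.h>

    /* Mirror the (i[1] % kdim[2]) indexing of the reference kernel: map a
     * query head index onto the key/value head it attends against. */
    static int kv_head_for_query_head(const int q_head, const int Hq, const int Hk)
    {
        assert(Hq >= Hk && Hq % Hk == 0); /* same preconditions the kernel asserts */
        return q_head % Hk;               /* Hq = 8, Hk = 2: heads 0,2,4,6 -> 0 and 1,3,5,7 -> 1 */
    }
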
diff --git a/lib/nnc/cmd/scaled_dot_product_attention/gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu b/lib/nnc/cmd/scaled_dot_product_attention/gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu
index 5c530b028..e6703231c 100644
--- a/lib/nnc/cmd/scaled_dot_product_attention/gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu
+++ b/lib/nnc/cmd/scaled_dot_product_attention/gpu/ccv_nnc_scaled_dot_product_attention_flash_attn.cu
@@ -79,14 +79,15 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 	int batch_size;
 	int R;
 	int C;
-	int H;
+	int Hq;
+	int Hk;
 	int D;
 	if (q_nd == 3) {
 		batch_size = qdim[1];
 		assert(batch_size == kdim[1]);
 		R = qdim[2];
 		C = kdim[2];
-		H = 1;
+		Hq = Hk = 1;
 		D = qdim[3];
 		assert(D == kdim[3]);
 	} else if (q_nd == 4) {
@@ -94,8 +95,10 @@
 		assert(batch_size == kdim[0]);
 		R = qdim[1];
 		C = kdim[1];
-		H = qdim[2];
-		assert(H == kdim[2]);
+		Hq = qdim[2];
+		Hk = kdim[2];
+		assert(Hq >= Hk);
+		assert(Hq % Hk == 0);
 		D = qdim[3];
 		assert(D == kdim[3]);
 	}
@@ -142,15 +145,15 @@
 	params.q_ptr = q->data.u8;
 	params.k_ptr = k->data.u8;
 	params.v_ptr = v->data.u8;
-	params.q_row_stride = D * H;
-	params.k_row_stride = D * H;
-	params.v_row_stride = D * H;
+	params.q_row_stride = D * Hq;
+	params.k_row_stride = D * Hk;
+	params.v_row_stride = D * Hk;
 	params.q_head_stride = D;
 	params.k_head_stride = D;
 	params.v_head_stride = D;
-	params.q_batch_stride = R * H * D;
-	params.k_batch_stride = C * H * D;
-	params.v_batch_stride = C * H * D;
+	params.q_batch_stride = R * Hq * D;
+	params.k_batch_stride = C * Hk * D;
+	params.v_batch_stride = C * Hk * D;
 	auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
 	params.seqlen_q = R;
 	params.seqlen_q_rounded = round_multiple(R, 128);
@@ -160,13 +163,13 @@
 	assert(D % 8 == 0);
 	params.d_rounded = round_multiple(D, 32);
 	params.o_ptr = o->data.u8;
-	params.o_row_stride = D * H;
+	params.o_row_stride = D * Hq;
 	params.o_head_stride = D;
-	params.o_batch_stride = R * H * D;
+	params.o_batch_stride = R * Hq * D;
 	params.b = batch_size;
-	params.h = H;
-	params.h_k = H;
-	params.h_h_k_ratio = 1;
+	params.h = Hq;
+	params.h_k = Hk;
+	params.h_h_k_ratio = Hq / Hk;
 	params.scale_softmax = cmd.info.scaled_dot_product_attention.scale;
 	params.scale_softmax_log2 = cmd.info.scaled_dot_product_attention.scale * M_LOG2E;
 	params.is_causal = cmd.info.scaled_dot_product_attention.is_causal;
@@ -177,7 +180,7 @@
 	params.window_size_left = ccv_max(R, C);
 	params.window_size_right = ccv_max(R, C);
 	params.is_seqlens_k_cumulative = true;
-	void* workspace = ccv_nnc_stream_context_get_workspace(stream_context, batch_size * H * R * sizeof(float), CCV_TENSOR_GPU_MEMORY);
+	void* workspace = ccv_nnc_stream_context_get_workspace(stream_context, batch_size * Hq * R * sizeof(float), CCV_TENSOR_GPU_MEMORY);
 	params.softmax_lse_ptr = workspace;
 	// TODO: Support num_splits.
 	// const int block_n = D <= 64 ? 256 : (D <= 128 ? 128 : 64);
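
Note on the flash attention change above: Q/O and K/V are described to the kernel as densely packed B x seqlen x heads x D buffers, so only the head count differs between the two sides (Hq for Q and O, Hk for K and V), and params.h_h_k_ratio passes the Hq / Hk group size along. A small sketch of the offset arithmetic those *_row_stride / *_head_stride / *_batch_stride assignments encode (helper name and parameters are illustrative):

    #include <stddef.h>

    /* Element offset of (b, s, h, d) in a packed B x S x H x D buffer:
     * batch stride = S * H * D, row stride = H * D, head stride = D,
     * matching the params.*_batch_stride / *_row_stride / *_head_stride above. */
    static size_t packed_bshd_offset(const size_t S, const size_t H, const size_t D,
        const size_t b, const size_t s, const size_t h, const size_t d)
    {
        return ((b * S + s) * H + h) * D + d;
    }
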
diff --git a/test/int/nnc/cublas.tests.c b/test/int/nnc/cublas.tests.c
index b2b0a9ffa..9b0481d9d 100644
--- a/test/int/nnc/cublas.tests.c
+++ b/test/int/nnc/cublas.tests.c
@@ -2613,61 +2613,63 @@ TEST_CASE("scaled dot product attention with flash_attn")
 {
 	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_GPU_REF));
 	// Bypass error: variable-sized object may not be initialized
-#define num_long_trials 2
+#define num_long_trials 4
 #define num_short_trials 2
 #define num_trials (num_long_trials + num_short_trials)
 	for (int trial = 0; trial < num_trials; ++trial) {
-		int B_candidates[num_trials] = { 32, 12, 16, 1 };
-		int R_candidates[num_trials] = { 160, 256, 128, 77 };
-		int C_candidates[num_trials] = { 128, 128, 128, 128 };
-		int H_candidates[num_trials] = { 8, 8, 8, 8 };
-		int D_candidates[num_trials] = { 64, 40, 160, 224 };
+		int B_candidates[num_trials] = { 32, 12, 16, 1, 2, 15 };
+		int R_candidates[num_trials] = { 160, 256, 128, 77, 77, 512 };
+		int C_candidates[num_trials] = { 128, 128, 128, 128, 128, 128 };
+		int Hq_candidates[num_trials] = { 8, 8, 8, 8, 8, 8 };
+		int Hk_candidates[num_trials] = { 8, 8, 8, 8, 2, 4 };
+		int D_candidates[num_trials] = { 64, 40, 160, 224, 224, 64 };
 		int B = B_candidates[trial];
 		int R = R_candidates[trial];
 		int C = C_candidates[trial];
-		int H = H_candidates[trial];
+		int Hq = Hq_candidates[trial];
+		int Hk = Hk_candidates[trial];
 		int D = D_candidates[trial];
 		float scale = 1.0 / sqrt((float)D);
 		GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_GPU_REF));
-		ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
-		ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, H, D), 0);
-		ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, H, D), 0);
+		ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
-		for (int i = 0; i < B * R * H * D; ++i) {
-			q_tensor->data.f32[i] = (float)(i) / (float)(B * R * H * D);
+		for (int i = 0; i < B * R * Hq * D; ++i) {
+			q_tensor->data.f32[i] = (float)(i) / (float)(B * R * Hq * D);
 		}
-		for (int i = 0; i < B * C * H * D; ++i) {
-			k_tensor->data.f32[i] = (float)(i) / (float)(B * C * H * D);
+		for (int i = 0; i < B * C * Hk * D; ++i) {
+			k_tensor->data.f32[i] = (float)(i) / (float)(B * C * Hk * D);
 		}
-		for (int i = 0; i < B * C * H * D; ++i) {
-			v_tensor->data.f32[i] = (float)(i) / (float)(B * C * H * D);
+		for (int i = 0; i < B * C * Hk * D; ++i) {
+			v_tensor->data.f32[i] = (float)(i) / (float)(B * C * Hk * D);
 		}
-		ccv_nnc_tensor_t* const o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
+		ccv_nnc_tensor_t* const o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
 		ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 0), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor, NULL, NULL, NULL), TENSOR_LIST(o_tensor, NULL), 0);
-		ccv_nnc_tensor_t* const q_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, H, D), 0);
-		ccv_nnc_tensor_t* const k_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, H, D), 0);
-		ccv_nnc_tensor_t* const v_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, H, D), 0);
+		ccv_nnc_tensor_t* const q_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const k_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const v_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, Hk, D), 0);
 		ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor), TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16), 0);
 		// Why it there 000 in the beginning of the argument list for GPU_TENSOR_NHWC?
-		ccv_nnc_tensor_t* const gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, H, D), 0);
-		ccv_nnc_tensor_t* const gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, H, D), 0);
-		ccv_nnc_tensor_t* const gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, H, D), 0);
-		ccv_nnc_tensor_t* const gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, H, D), 0);
+		ccv_nnc_tensor_t* const gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, Hq, D), 0);
 		ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16), TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor), 0);
 		ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 0), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor, NULL, NULL, NULL), TENSOR_LIST(gpu_o_tensor, NULL), 0);
-		ccv_nnc_tensor_t* const copy_of_gpu_o_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, H, D), 0);
+		ccv_nnc_tensor_t* const copy_of_gpu_o_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, Hq, D), 0);
 		ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_o_tensor), TENSOR_LIST(copy_of_gpu_o_tensor_f16), 0);
-		ccv_nnc_tensor_t* const copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
+		ccv_nnc_tensor_t* const copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
 		ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(copy_of_gpu_o_tensor_f16), TENSOR_LIST(copy_of_gpu_o_tensor), 0);
-		REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, copy_of_gpu_o_tensor->data.f32, o_tensor->data.f32, B * R * H * D, 1e-3, "GPU computed output should be the same as CPU computed ones");
+		REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, copy_of_gpu_o_tensor->data.f32, o_tensor->data.f32, B * R * Hq * D, 3e-3, "GPU computed output should be the same as CPU computed ones");
 		ccv_nnc_tensor_free(o_tensor);
 		ccv_nnc_tensor_free(gpu_o_tensor);
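
Note: with this patch the same forward command accepts K and V tensors that carry fewer heads than Q. A minimal invocation sketch mirroring one of the new GQA trials in the test above (dimensions taken from that trial; the usual ccv_nnc includes, teardown, and error handling are omitted):

    // Q has 8 heads, K/V share 2 heads; shapes follow the test's NHWC layout.
    const int B = 2, R = 77, C = 128, Hq = 8, Hk = 2, D = 224;
    const float scale = 1.0 / sqrt((float)D);
    ccv_nnc_tensor_t* const q = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
    ccv_nnc_tensor_t* const k = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
    ccv_nnc_tensor_t* const v = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
    ccv_nnc_tensor_t* const o = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
    ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 0), ccv_nnc_no_hint, 0,
        TENSOR_LIST(q, k, v, NULL, NULL, NULL), TENSOR_LIST(o, NULL), 0);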