Add grouped query attention invocation from SDP.
liuliu committed Dec 16, 2023
1 parent a916bbf commit 61f086d
Showing 3 changed files with 62 additions and 51 deletions.
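
This commit teaches both the CPU reference kernel and the flash-attention invocation for scaled dot product attention (SDP) about grouped query attention (GQA): the query tensor may carry more heads (Hq) than the key/value tensors (Hk), provided Hq is a multiple of Hk, so that several query heads share one KV head. In the CPU reference below, the sharing is expressed by indexing the KV head with i[1] % kdim[2]. The following is a minimal, self-contained sketch of that head mapping; the function and variable names are illustrative and not part of the repository:

#include <assert.h>
#include <stdio.h>

/* Sketch of the GQA head mapping used by the CPU reference in this commit:
 * query head hq reads key/value head hq % Hk. Requires Hq % Hk == 0. */
static int gqa_kv_head(int hq, int Hq, int Hk)
{
	assert(Hq % Hk == 0 && hq < Hq);
	return hq % Hk;
}

int main(void)
{
	const int Hq = 8, Hk = 2; /* matches the new test trial with 8 query heads and 2 KV heads */
	for (int hq = 0; hq < Hq; hq++)
		printf("query head %d -> kv head %d\n", hq, gqa_kv_head(hq, Hq, Hk));
	/* prints 0->0, 1->1, 2->0, 3->1, ...; each KV head serves Hq / Hk = 4 query heads */
	return 0;
}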
@@ -55,7 +55,10 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 cdim[0] = cdim[1], cdim[1] = cdim[2], cdim[2] = 1;
 }
 assert(qdim[0] == kdim[0] && kdim[0] == vdim[0] && vdim[0] == cdim[0]);
-assert(qdim[2] == kdim[2] && kdim[2] == vdim[2] && vdim[2] == cdim[2]);
+assert(qdim[2] == cdim[2]);
+assert(kdim[2] == vdim[2]);
+assert(qdim[2] % kdim[2] == 0);
+assert(qdim[2] >= kdim[2]);
 assert(qdim[3] == kdim[3]);
 assert(kdim[1] == vdim[1]);
 assert(cdim[1] == qdim[1]);
@@ -117,8 +120,8 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 for (i[1] = 0; i[1] < qdim[2]; i[1]++)
 {
 const float* const qp1 = qp0 + i[1] * qstride[2];
-const float* const kp1 = kp0 + i[1] * kstride[2];
-const float* const vp1 = vp0 + i[1] * vstride[2];
+const float* const kp1 = kp0 + (i[1] % kdim[2]) * kstride[2];
+const float* const vp1 = vp0 + (i[1] % vdim[2]) * vstride[2];
 const float* const amp1 = amp && amdim[1] > 1 ? amp0 + i[1] * amstride[1] : amp0;
 float* const cp1 = cp0 + i[1] * cstride[2];
 float* const ssp1 = ssp0 ? ssp0 + i[1] * ssstride[1] : 0;
@@ -327,7 +330,10 @@ static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, c
 dvdim[0] = dvdim[1], dvdim[1] = dvdim[2], dvdim[2] = 1;
 }
 assert(qdim[0] == kdim[0] && kdim[0] == vdim[0] && vdim[0] == gdim[0]);
-assert(qdim[2] == kdim[2] && kdim[2] == vdim[2] && vdim[2] == gdim[2]);
+assert(qdim[2] == gdim[2]);
+assert(kdim[2] == vdim[2]);
+assert(qdim[2] % kdim[2] == 0);
+assert(qdim[2] >= kdim[2]);
 assert(qdim[3] == kdim[3]);
 assert(kdim[1] == vdim[1]);
 assert(gdim[1] == qdim[1]);
@@ -379,12 +385,12 @@ static int _ccv_nnc_scaled_dot_product_attention_back(const ccv_nnc_cmd_t cmd, c
 for (i[1] = 0; i[1] < qdim[2]; i[1]++)
 {
 const float* const qp1 = qp0 + i[1] * qstride[2];
-const float* const kp1 = kp0 + i[1] * kstride[2];
-const float* const vp1 = vp0 + i[1] * vstride[2];
+const float* const kp1 = kp0 + (i[1] % kdim[2]) * kstride[2];
+const float* const vp1 = vp0 + (i[1] % vdim[2]) * vstride[2];
 const float* const gp1 = gp0 + i[1] * gstride[2];
 float* const dqp1 = dqp0 + i[1] * dqstride[2];
-float* const dkp1 = dkp0 + i[1] * dkstride[2];
-float* const dvp1 = dvp0 + i[1] * dvstride[2];
+float* const dkp1 = dkp0 + (i[1] % dkdim[2]) * dkstride[2];
+float* const dvp1 = dvp0 + (i[1] % dvdim[2]) * dvstride[2];
 // Compute Q @ K^T
 int x, y, k;
 for (y = 0; y < kdim[1]; y++)
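
The backward hunk above routes the dK and dV writes through the same modulo mapping, so every query head that shares a KV head addresses the same dK/dV slice; mathematically the gradient of a shared KV head is the sum of the contributions from each query head that reads it. A minimal sketch of that accumulation for one batch element, using plain float buffers and illustrative names rather than the library's strided tensors:

#include <string.h>

/* Sketch only: accumulate per-query-head dV contributions into shared KV heads.
 * dv holds Hk * D floats, dv_per_q holds Hq * D floats of per-query-head gradients.
 * The repository's kernel operates on strided ccv_nnc tensors instead. */
static void gqa_accumulate_dv(float* dv, const float* dv_per_q, int Hq, int Hk, int D)
{
	memset(dv, 0, sizeof(float) * Hk * D); /* shared slices must start from zero */
	for (int hq = 0; hq < Hq; hq++) {
		float* const dst = dv + (hq % Hk) * D; /* same modulo mapping as the forward pass */
		const float* const src = dv_per_q + hq * D;
		for (int d = 0; d < D; d++)
			dst[d] += src[d]; /* sum, not overwrite: several query heads share this KV head */
	}
}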
@@ -79,23 +79,26 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 int batch_size;
 int R;
 int C;
-int H;
+int Hq;
+int Hk;
 int D;
 if (q_nd == 3) {
 batch_size = qdim[1];
 assert(batch_size == kdim[1]);
 R = qdim[2];
 C = kdim[2];
-H = 1;
+Hq = Hk = 1;
 D = qdim[3];
 assert(D == kdim[3]);
 } else if (q_nd == 4) {
 batch_size = qdim[0];
 assert(batch_size == kdim[0]);
 R = qdim[1];
 C = kdim[1];
-H = qdim[2];
-assert(H == kdim[2]);
+Hq = qdim[2];
+Hk = kdim[2];
+assert(Hq >= Hk);
+assert(Hq % Hk == 0);
 D = qdim[3];
 assert(D == kdim[3]);
 }
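
For the 4-D layout, the hunk above reads both head counts directly from the tensor dimensions. Taking one of the new test trials as a worked example, q of shape (2, 77, 8, 224) and k/v of shape (2, 128, 2, 224) give batch_size = 2, R = 77, C = 128, Hq = 8, Hk = 2 and D = 224, and 8 % 2 == 0 satisfies the divisibility check. A small sketch of that bookkeeping, with plain dimension arrays standing in for ccv_nnc tensor views and illustrative names:

#include <assert.h>

/* Sketch of the shape bookkeeping for the 4-D (batch, seqlen, heads, head-dim) case. */
static void gqa_shapes(const int qdim[4], const int kdim[4],
	int* batch_size, int* R, int* C, int* Hq, int* Hk, int* D)
{
	*batch_size = qdim[0];
	assert(*batch_size == kdim[0]);
	*R = qdim[1];  /* query sequence length */
	*C = kdim[1];  /* key/value sequence length */
	*Hq = qdim[2]; /* query heads */
	*Hk = kdim[2]; /* key/value heads */
	assert(*Hq >= *Hk && *Hq % *Hk == 0);
	*D = qdim[3];  /* head dimension */
	assert(*D == kdim[3]);
}
/* e.g. qdim = {2, 77, 8, 224}, kdim = {2, 128, 2, 224}
 *   -> batch_size 2, R 77, C 128, Hq 8, Hk 2, D 224 */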
@@ -142,15 +145,15 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 params.q_ptr = q->data.u8;
 params.k_ptr = k->data.u8;
 params.v_ptr = v->data.u8;
-params.q_row_stride = D * H;
-params.k_row_stride = D * H;
-params.v_row_stride = D * H;
+params.q_row_stride = D * Hq;
+params.k_row_stride = D * Hk;
+params.v_row_stride = D * Hk;
 params.q_head_stride = D;
 params.k_head_stride = D;
 params.v_head_stride = D;
-params.q_batch_stride = R * H * D;
-params.k_batch_stride = C * H * D;
-params.v_batch_stride = C * H * D;
+params.q_batch_stride = R * Hq * D;
+params.k_batch_stride = C * Hk * D;
+params.v_batch_stride = C * Hk * D;
 auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
 params.seqlen_q = R;
 params.seqlen_q_rounded = round_multiple(R, 128);
@@ -160,13 +163,13 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 assert(D % 8 == 0);
 params.d_rounded = round_multiple(D, 32);
 params.o_ptr = o->data.u8;
-params.o_row_stride = D * H;
+params.o_row_stride = D * Hq;
 params.o_head_stride = D;
-params.o_batch_stride = R * H * D;
+params.o_batch_stride = R * Hq * D;
 params.b = batch_size;
-params.h = H;
-params.h_k = H;
-params.h_h_k_ratio = 1;
+params.h = Hq;
+params.h_k = Hk;
+params.h_h_k_ratio = Hq / Hk;
 params.scale_softmax = cmd.info.scaled_dot_product_attention.scale;
 params.scale_softmax_log2 = cmd.info.scaled_dot_product_attention.scale * M_LOG2E;
 params.is_causal = cmd.info.scaled_dot_product_attention.is_causal;
@@ -177,7 +180,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 params.window_size_left = ccv_max(R, C);
 params.window_size_right = ccv_max(R, C);
 params.is_seqlens_k_cumulative = true;
-void* workspace = ccv_nnc_stream_context_get_workspace(stream_context, batch_size * H * R * sizeof(float), CCV_TENSOR_GPU_MEMORY);
+void* workspace = ccv_nnc_stream_context_get_workspace(stream_context, batch_size * Hq * R * sizeof(float), CCV_TENSOR_GPU_MEMORY);
 params.softmax_lse_ptr = workspace;
 // TODO: Support num_splits.
 // const int block_n = D <= 64 ? 256 : (D <= 128 ? 128 : 64);
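
The parameter block above assumes the packed (batch, seqlen, heads, head-dim) layout, so each row stride is heads × head-dim and each batch stride is seqlen × heads × head-dim, with the query side using Hq, the key/value side using Hk, and the ratio Hq / Hk handed to the kernel as h_h_k_ratio. A minimal sketch of that arithmetic, with an illustrative struct standing in for the real flash-attention parameter block:

/* Illustrative stand-in for the flash-attention parameter block; only the
 * fields touched by this commit are sketched here. */
struct gqa_params_sketch {
	int q_row_stride, k_row_stride, v_row_stride, o_row_stride;
	int q_batch_stride, k_batch_stride, v_batch_stride, o_batch_stride;
	int h, h_k, h_h_k_ratio;
};

/* Fill strides for packed (B, seq, H, D) tensors: q/o use Hq heads over R rows,
 * k/v use Hk heads over C rows. Assumes Hq % Hk == 0, as asserted earlier. */
static void gqa_fill_params(struct gqa_params_sketch* p, int R, int C, int Hq, int Hk, int D)
{
	p->q_row_stride = D * Hq;  p->o_row_stride = D * Hq;
	p->k_row_stride = D * Hk;  p->v_row_stride = D * Hk;
	p->q_batch_stride = R * Hq * D;  p->o_batch_stride = R * Hq * D;
	p->k_batch_stride = C * Hk * D;  p->v_batch_stride = C * Hk * D;
	p->h = Hq;
	p->h_k = Hk;
	p->h_h_k_ratio = Hq / Hk; /* how many query heads share each KV head */
}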
56 changes: 29 additions & 27 deletions test/int/nnc/cublas.tests.c
@@ -2613,61 +2613,63 @@ TEST_CASE("scaled dot product attention with flash_attn")
 {
 GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_GPU_REF));
 // Bypass error: variable-sized object may not be initialized
-#define num_long_trials 2
+#define num_long_trials 4
 #define num_short_trials 2
 #define num_trials (num_long_trials + num_short_trials)
 
 for (int trial = 0; trial < num_trials; ++trial) {
-int B_candidates[num_trials] = { 32, 12, 16, 1 };
-int R_candidates[num_trials] = { 160, 256, 128, 77 };
-int C_candidates[num_trials] = { 128, 128, 128, 128 };
-int H_candidates[num_trials] = { 8, 8, 8, 8 };
-int D_candidates[num_trials] = { 64, 40, 160, 224 };
+int B_candidates[num_trials] = { 32, 12, 16, 1, 2, 15 };
+int R_candidates[num_trials] = { 160, 256, 128, 77, 77, 512 };
+int C_candidates[num_trials] = { 128, 128, 128, 128, 128, 128 };
+int Hq_candidates[num_trials] = { 8, 8, 8, 8, 8, 8 };
+int Hk_candidates[num_trials] = { 8, 8, 8, 8, 2, 4 };
+int D_candidates[num_trials] = { 64, 40, 160, 224, 224, 64 };
 
 int B = B_candidates[trial];
 int R = R_candidates[trial];
 int C = C_candidates[trial];
-int H = H_candidates[trial];
+int Hq = Hq_candidates[trial];
+int Hk = Hk_candidates[trial];
 int D = D_candidates[trial];
 float scale = 1.0 / sqrt((float)D);
 
 GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_GPU_REF));
-ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
-ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, H, D), 0);
-ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, H, D), 0);
+ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
+ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
+ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
 
-for (int i = 0; i < B * R * H * D; ++i) {
-q_tensor->data.f32[i] = (float)(i) / (float)(B * R * H * D);
+for (int i = 0; i < B * R * Hq * D; ++i) {
+q_tensor->data.f32[i] = (float)(i) / (float)(B * R * Hq * D);
 }
-for (int i = 0; i < B * C * H * D; ++i) {
-k_tensor->data.f32[i] = (float)(i) / (float)(B * C * H * D);
+for (int i = 0; i < B * C * Hk * D; ++i) {
+k_tensor->data.f32[i] = (float)(i) / (float)(B * C * Hk * D);
 }
-for (int i = 0; i < B * C * H * D; ++i) {
-v_tensor->data.f32[i] = (float)(i) / (float)(B * C * H * D);
+for (int i = 0; i < B * C * Hk * D; ++i) {
+v_tensor->data.f32[i] = (float)(i) / (float)(B * C * Hk * D);
 }
 
-ccv_nnc_tensor_t* const o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
+ccv_nnc_tensor_t* const o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
 ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 0), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor, NULL, NULL, NULL), TENSOR_LIST(o_tensor, NULL), 0);
-ccv_nnc_tensor_t* const q_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, H, D), 0);
-ccv_nnc_tensor_t* const k_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, H, D), 0);
-ccv_nnc_tensor_t* const v_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, H, D), 0);
+ccv_nnc_tensor_t* const q_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, Hq, D), 0);
+ccv_nnc_tensor_t* const k_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, Hk, D), 0);
+ccv_nnc_tensor_t* const v_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, Hk, D), 0);
 ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor), TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16), 0);
 
 // Why it there 000 in the beginning of the argument list for GPU_TENSOR_NHWC?
-ccv_nnc_tensor_t* const gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, H, D), 0);
-ccv_nnc_tensor_t* const gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, H, D), 0);
-ccv_nnc_tensor_t* const gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, H, D), 0);
-ccv_nnc_tensor_t* const gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, H, D), 0);
+ccv_nnc_tensor_t* const gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, Hq, D), 0);
+ccv_nnc_tensor_t* const gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, Hk, D), 0);
+ccv_nnc_tensor_t* const gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, Hk, D), 0);
+ccv_nnc_tensor_t* const gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, Hq, D), 0);
 ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16), TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor), 0);
 
 ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 0), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor, NULL, NULL, NULL), TENSOR_LIST(gpu_o_tensor, NULL), 0);
 
-ccv_nnc_tensor_t* const copy_of_gpu_o_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, H, D), 0);
+ccv_nnc_tensor_t* const copy_of_gpu_o_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, Hq, D), 0);
 ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_o_tensor), TENSOR_LIST(copy_of_gpu_o_tensor_f16), 0);
-ccv_nnc_tensor_t* const copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, H, D), 0);
+ccv_nnc_tensor_t* const copy_of_gpu_o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
 ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(copy_of_gpu_o_tensor_f16), TENSOR_LIST(copy_of_gpu_o_tensor), 0);
 
-REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, copy_of_gpu_o_tensor->data.f32, o_tensor->data.f32, B * R * H * D, 1e-3, "GPU computed output should be the same as CPU computed ones");
+REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, copy_of_gpu_o_tensor->data.f32, o_tensor->data.f32, B * R * Hq * D, 3e-3, "GPU computed output should be the same as CPU computed ones");
 
 ccv_nnc_tensor_free(o_tensor);
 ccv_nnc_tensor_free(gpu_o_tensor);
