Fix: is_causal should match between GPU / CPU on sdp.
liuliu committed Jan 1, 2024
1 parent 75afa90 commit ee57ca5
Showing 3 changed files with 17 additions and 17 deletions.
```diff
@@ -156,21 +156,19 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 	// Compute softmax on qk.
 	if (is_causal)
 	{
-		for (y = 0; y < ccv_min(x, kdim[1]); y++)
+		const int x_end = ccv_max(x - qdim[1] + kdim[1] + 1, 0);
+		for (y = x_end; y < kdim[1]; y++)
 			qk0[y] = 0;
-		if (x < kdim[1])
-		{
-			double maxval = qk0[x];
-			for (y = x + 1; y < kdim[1]; y++)
-				if (qk0[y] > maxval)
-					maxval = qk0[y];
-			double sumval = 0;
-			for (y = x; y < kdim[1]; y++)
-				sumval += (qk0[y] = expf(qk0[y] - maxval));
-			sumval = 1.0 / sumval;
-			for (y = x; y < kdim[1]; y++)
-				qk0[y] *= sumval;
-		}
+		double maxval = qk0[0];
+		for (y = 1; y < x_end; y++)
+			if (qk0[y] > maxval)
+				maxval = qk0[y];
+		double sumval = 0;
+		for (y = 0; y < x_end; y++)
+			sumval += (qk0[y] = expf(qk0[y] - maxval));
+		sumval = 1.0 / sumval;
+		for (y = 0; y < x_end; y++)
+			qk0[y] *= sumval;
 	} else {
 		double maxval = qk0[0];
 		for (y = 1; y < kdim[1]; y++)
```
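The new bound aligns the causal mask to the bottom-right of the R x C attention matrix, which is the convention flash-attention uses when the query and key lengths differ: the last query row sees every key, and each earlier row sees one fewer. Reading qdim[1] as the query length R and kdim[1] as the key length C, here is a minimal standalone sketch of the visibility count (visible_keys is a hypothetical helper, not part of the library):

```c
#include <stdio.h>

/* Bottom-right-aligned causal mask: query row x attends to key columns
 * [0, x_end) with x_end = max(x - R + C + 1, 0), mirroring the x_end
 * computed in the patched CPU kernel above. */
static int visible_keys(int x, int R, int C)
{
	const int x_end = x - R + C + 1;
	return x_end > 0 ? x_end : 0;
}

int main(void)
{
	const int R = 3, C = 5; /* e.g. 3 new queries against a 5-token key cache */
	for (int x = 0; x < R; x++)
		printf("query row %d attends to %d of %d keys\n", x, visible_keys(x, R, C), C);
	/* row 0 -> 3 keys, row 1 -> 4 keys, row 2 -> 5 keys */
	return 0;
}
```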
```diff
@@ -178,7 +178,7 @@ static int _ccv_nnc_scaled_dot_product_attention_forw(const ccv_nnc_cmd_t cmd, c
 	params.rp_dropout = 1;
 	params.scale_softmax_rp_dropout = params.scale_softmax;
 	params.window_size_left = ccv_max(R, C);
-	params.window_size_right = ccv_max(R, C);
+	params.window_size_right = params.is_causal ? 0 : ccv_max(R, C);
 	params.is_seqlens_k_cumulative = true;
 	void* workspace = ccv_nnc_stream_context_get_workspace(stream_context, batch_size * Hq * R * sizeof(float), CCV_TENSOR_GPU_MEMORY);
 	params.softmax_lse_ptr = workspace;
```
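On the GPU path, flash-attention expresses causal masking through its sliding-window parameters: with the bottom-right alignment above, query row i sits on key column i + C - R, and window_size_right bounds how far past that diagonal a key may be, so setting it to 0 when params.is_causal is set forbids any look-ahead and matches the CPU kernel's x_end bound. A hedged sketch of the visibility predicate as I read those semantics (window_visible is illustrative, not the kernel's code):

```c
/* Sliding-window visibility: key column j is visible from query row i
 * iff diag - left <= j <= diag + right, where diag = i + C - R is the
 * bottom-right-aligned diagonal. right == 0 therefore means causal:
 * with left >= max(R, C), visibility reduces to j < i - R + C + 1,
 * the same x_end as the CPU reference. */
static int window_visible(int i, int j, int R, int C, int left, int right)
{
	const int diag = i + C - R;
	return j >= diag - left && j <= diag + right;
}
```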
test/int/nnc/cublas.tests.c (6 changes: 4 additions & 2 deletions)
```diff
@@ -2624,13 +2624,15 @@ TEST_CASE("scaled dot product attention with flash_attn")
 	int Hq_candidates[num_trials] = { 8, 8, 8, 8, 8, 8 };
 	int Hk_candidates[num_trials] = { 8, 8, 8, 8, 2, 4 };
 	int D_candidates[num_trials] = { 64, 40, 160, 224, 224, 64 };
+	int is_causal_candidates[num_trials] = { 1, 0, 1, 1, 0, 0 };
 
 	int B = B_candidates[trial];
 	int R = R_candidates[trial];
 	int C = C_candidates[trial];
 	int Hq = Hq_candidates[trial];
 	int Hk = Hk_candidates[trial];
 	int D = D_candidates[trial];
+	int is_causal = is_causal_candidates[trial];
 	float scale = 1.0 / sqrt((float)D);
 
 	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_GPU_REF));
@@ -2649,7 +2651,7 @@ TEST_CASE("scaled dot product attention with flash_attn")
 	}
 
 	ccv_nnc_tensor_t* const o_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
-	ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 0), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor, NULL, NULL, NULL), TENSOR_LIST(o_tensor, NULL), 0);
+	ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, is_causal), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor, NULL, NULL, NULL), TENSOR_LIST(o_tensor, NULL), 0);
 	ccv_nnc_tensor_t* const q_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, Hq, D), 0);
 	ccv_nnc_tensor_t* const k_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, Hk, D), 0);
 	ccv_nnc_tensor_t* const v_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, Hk, D), 0);
@@ -2662,7 +2664,7 @@ TEST_CASE("scaled dot product attention with flash_attn")
 	ccv_nnc_tensor_t* const gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, Hq, D), 0);
 	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16), TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor), 0);
 
-	ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, 0), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor, NULL, NULL, NULL), TENSOR_LIST(gpu_o_tensor, NULL), 0);
+	ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, is_causal), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor, NULL, NULL, NULL), TENSOR_LIST(gpu_o_tensor, NULL), 0);
 
 	ccv_nnc_tensor_t* const copy_of_gpu_o_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, Hq, D), 0);
 	ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_o_tensor), TENSOR_LIST(copy_of_gpu_o_tensor_f16), 0);
```
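With is_causal threaded through the candidate table, trials 0, 2, and 3 now exercise causal attention on both paths: the float32 CPU reference computed above and the float16 flash_attn kernel on the GPU. In any trial where R and C differ, the test's eventual comparison of the copied-back GPU output against o_tensor would catch a disagreement in how the two backends align the causal mask, which is exactly the mismatch this commit fixes.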