From fe99c482150380422b9215233f5eb271efee4a79 Mon Sep 17 00:00:00 2001
From: Liu Liu
Date: Wed, 30 Oct 2024 01:17:03 -0400
Subject: [PATCH] Fix various issues related to running backprop of sdpa on fp16.

---
 lib/nnc/mfa/ccv_nnc_mfa_attention.cpp  |  55 ++++++------
 lib/nnc/mfa/v2/AttentionDescriptor.cpp |  21 +++--
 test/int/nnc/mpsblas.tests.c           | 117 +++++++++++++++++++++++++
 3 files changed, 160 insertions(+), 33 deletions(-)

diff --git a/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp b/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
index 9e97cc89c..87ebb9c4e 100644
--- a/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
+++ b/lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
@@ -128,7 +128,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   if (tensors[5]) {
     encoder->useResource(tensors[5], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
   }
-  encoder->setBuffer(tensors[0], tensor_offsets[0], AttentionOperand(AttentionOperand::Q).bufferIndex());
+  encoder->setBuffer(tensors[0], tensor_offsets[0], AttentionOperand(AttentionOperand::Q).bufferIndex());
   encoder->setBuffer(tensors[1], tensor_offsets[1], AttentionOperand(AttentionOperand::K).bufferIndex());
   encoder->setBuffer(tensors[2], tensor_offsets[2], AttentionOperand(AttentionOperand::V).bufferIndex());
   if (attentionDesc.lowPrecisionInputs) {
@@ -198,6 +198,15 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   auto backwardKeyValueKernel = backwardKeyValuePipelineValue->kernel;
   auto backwardKeyValuePipeline = backwardKeyValuePipelineValue->pipeline;
 
+  auto scratch_size = 0;
+  if (attentionDesc.lowPrecisionInputs) {
+    // Need scratch space for FP16 output.
+    scratch_size += sizeof(float) * (hash.R + hash.C * 2) * hash.D * hash.Hq * attentionDesc.batchDimension;
+  }
+  // Need scratch space for D.
+  scratch_size += sizeof(float) * hash.R * hash.Hq * attentionDesc.batchDimension;
+  auto scratch = context->request_scratch(scratch_size);
+
   // Allocate a new command.
   auto backwardQueryEncoder = command_batch->startCommand();
   backwardQueryEncoder->setComputePipelineState(backwardQueryPipeline.get());
@@ -210,14 +219,6 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   backwardQueryEncoder->useResource(tensors[3], MTL::ResourceUsageRead);
   backwardQueryEncoder->useResource(tensors[4], MTL::ResourceUsageRead);
   backwardQueryEncoder->useResource(tensors[5], MTL::ResourceUsageRead);
-  auto scratch_size = 0;
-  if (attentionDesc.lowPrecisionInputs) {
-    // Need scratch space for FP16 output.
-    scratch_size += sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension * 3;
-  }
-  // Need scratch space for D.
-  scratch_size += sizeof(float) * hash.R * hash.Hq * attentionDesc.batchDimension;
-  auto scratch = context->request_scratch(scratch_size);
   backwardQueryEncoder->useResource(scratch, MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
   if (!attentionDesc.lowPrecisionInputs) {
     backwardQueryEncoder->useResource(tensors[6], MTL::ResourceUsageWrite);
   }
@@ -231,7 +232,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
   backwardQueryEncoder->setBuffer(tensors[5], tensor_offsets[5], AttentionOperand(AttentionOperand::dO).bufferIndex());
   if (attentionDesc.lowPrecisionInputs) {
     backwardQueryEncoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::dQ).bufferIndex());
-    backwardQueryEncoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension * 3, AttentionOperand(AttentionOperand::D).bufferIndex());
+    backwardQueryEncoder->setBuffer(scratch, sizeof(float) * (hash.R + hash.C * 2) * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::D).bufferIndex());
   } else {
     backwardQueryEncoder->setBuffer(tensors[6], tensor_offsets[6], AttentionOperand(AttentionOperand::dQ).bufferIndex());
     backwardQueryEncoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::D).bufferIndex());
@@ -268,24 +269,24 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
     backwardKeyValueEncoder->useResource(tensors[8], MTL::ResourceUsageWrite);
   }
 
-  backwardQueryEncoder->setBuffer(tensors[0], tensor_offsets[0], AttentionOperand(AttentionOperand::Q).bufferIndex());
-  backwardQueryEncoder->setBuffer(tensors[1], tensor_offsets[1], AttentionOperand(AttentionOperand::K).bufferIndex());
-  backwardQueryEncoder->setBuffer(tensors[2], tensor_offsets[2], AttentionOperand(AttentionOperand::V).bufferIndex());
-  backwardQueryEncoder->setBuffer(tensors[3], tensor_offsets[3], AttentionOperand(AttentionOperand::O).bufferIndex());
-  backwardQueryEncoder->setBuffer(tensors[4], tensor_offsets[4], AttentionOperand(AttentionOperand::L).bufferIndex());
-  backwardQueryEncoder->setBuffer(tensors[5], tensor_offsets[5], AttentionOperand(AttentionOperand::dO).bufferIndex());
+  backwardKeyValueEncoder->setBuffer(tensors[0], tensor_offsets[0], AttentionOperand(AttentionOperand::Q).bufferIndex());
+  backwardKeyValueEncoder->setBuffer(tensors[1], tensor_offsets[1], AttentionOperand(AttentionOperand::K).bufferIndex());
+  backwardKeyValueEncoder->setBuffer(tensors[2], tensor_offsets[2], AttentionOperand(AttentionOperand::V).bufferIndex());
+  backwardKeyValueEncoder->setBuffer(tensors[3], tensor_offsets[3], AttentionOperand(AttentionOperand::O).bufferIndex());
+  backwardKeyValueEncoder->setBuffer(tensors[4], tensor_offsets[4], AttentionOperand(AttentionOperand::L).bufferIndex());
+  backwardKeyValueEncoder->setBuffer(tensors[5], tensor_offsets[5], AttentionOperand(AttentionOperand::dO).bufferIndex());
   if (attentionDesc.lowPrecisionInputs) {
-    backwardQueryEncoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::dK).bufferIndex());
-    backwardQueryEncoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension * 2, AttentionOperand(AttentionOperand::dV).bufferIndex());
-    backwardQueryEncoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension * 3, AttentionOperand(AttentionOperand::D).bufferIndex());
+    backwardKeyValueEncoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::dK).bufferIndex());
+    backwardKeyValueEncoder->setBuffer(scratch, sizeof(float) * (hash.R + hash.C) * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::dV).bufferIndex());
+    backwardKeyValueEncoder->setBuffer(scratch, sizeof(float) * (hash.R + hash.C * 2) * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::D).bufferIndex());
   } else {
-    backwardQueryEncoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::D).bufferIndex());
-    backwardQueryEncoder->setBuffer(tensors[7], tensor_offsets[7], AttentionOperand(AttentionOperand::dK).bufferIndex());
-    backwardQueryEncoder->setBuffer(tensors[8], tensor_offsets[8], AttentionOperand(AttentionOperand::dV).bufferIndex());
+    backwardKeyValueEncoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::D).bufferIndex());
+    backwardKeyValueEncoder->setBuffer(tensors[7], tensor_offsets[7], AttentionOperand(AttentionOperand::dK).bufferIndex());
+    backwardKeyValueEncoder->setBuffer(tensors[8], tensor_offsets[8], AttentionOperand(AttentionOperand::dV).bufferIndex());
   }
 
   MTL::Size backwardKeyValueGridSize
-    (ceilDivide(int64_t(hash.R), backwardKeyValueKernel->blockDimensions[0]),
+    (ceilDivide(int64_t(hash.C), backwardKeyValueKernel->blockDimensions[0]),
      hash.Hq, attentionDesc.batchDimension);
   MTL::Size backwardKeyValueGroupSize
@@ -315,12 +316,14 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
       tensor_offsets[6]
     };
     ccv_nnc_mfa_encode_cast(context, cast_params, command_batch, cast_tensors, cast_tensor_offsets);
-    cast_tensors[1] = tensors[7];
+    cast_params.length = hash.C * hash.D * hash.Hq * attentionDesc.batchDimension;
+    ccv_nnc_mfa_prepare_cast(context, cast_params);
+    cast_tensors[1] = tensors[7];
     cast_tensor_offsets[0] = sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension;
     cast_tensor_offsets[1] = tensor_offsets[7];
     ccv_nnc_mfa_encode_cast(context, cast_params, command_batch, cast_tensors, cast_tensor_offsets);
-    cast_tensors[1] = tensors[8];
-    cast_tensor_offsets[0] = sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension * 2;
+    cast_tensors[1] = tensors[8];
+    cast_tensor_offsets[0] = sizeof(float) * (hash.R + hash.C) * hash.D * hash.Hq * attentionDesc.batchDimension;
     cast_tensor_offsets[1] = tensor_offsets[8];
     ccv_nnc_mfa_encode_cast(context, cast_params, command_batch, cast_tensors, cast_tensor_offsets);
   }
diff --git a/lib/nnc/mfa/v2/AttentionDescriptor.cpp b/lib/nnc/mfa/v2/AttentionDescriptor.cpp
index 5f98864fa..567d7d171 100644
--- a/lib/nnc/mfa/v2/AttentionDescriptor.cpp
+++ b/lib/nnc/mfa/v2/AttentionDescriptor.cpp
@@ -199,7 +199,7 @@ AttentionOperands AttentionDescriptor::createMemoryPrecisi
     memoryPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP16;
     memoryPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP16;
     memoryPrecisions[AttentionOperand::V] = GEMMOperandPrecision::FP16;
-    memoryPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP32; // GEMMOperandPrecision::BF16;
+    memoryPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP16;
   } else {
     memoryPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP32;
     memoryPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP32;
@@ -321,8 +321,11 @@ AttentionOperands AttentionDescriptor::createMemoryPrecisi
   // will always write O as FP32 in memory. This choice simplifies
   // everything, just like the choice to always store log-sum-exp during the
   // forward pass. It also removes the concern of rounding error from
-  // frequently truncating the FP32 numbers to FP16.
-  memoryPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP32;
+  if (type.value != AttentionKernelType::forward && lowPrecisionInputs) {
+    memoryPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP16;
+  } else {
+    memoryPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP32;
+  }
   memoryPrecisions[AttentionOperand::dV] = GEMMOperandPrecision::FP32;
   memoryPrecisions[AttentionOperand::dK] = GEMMOperandPrecision::FP32;
   memoryPrecisions[AttentionOperand::dQ] = GEMMOperandPrecision::FP32;
@@ -342,7 +345,7 @@ AttentionOperands AttentionDescriptor::createRegisterPreci
     registerPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP16;
     registerPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP16;
     registerPrecisions[AttentionOperand::V] = GEMMOperandPrecision::FP16;
-    registerPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
+    registerPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP16;
   } else {
     registerPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP32;
     registerPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP32;
@@ -353,7 +356,7 @@ AttentionOperands AttentionDescriptor::createRegisterPreci
   // The register precision of L/D only counts for backward key-value.
   if (lowPrecisionIntermediates) {
     registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16;
-    registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
+    registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
   } else {
     registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32;
     registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
   }
@@ -380,7 +383,7 @@ AttentionOperands AttentionDescriptor::createRegisterPreci
     registerPrecisions[AttentionOperand::S] = lowPrecisionInputs ? GEMMOperandPrecision::FP16 : GEMMOperandPrecision::FP32;
     registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP16;
     registerPrecisions[AttentionOperand::dP] = GEMMOperandPrecision::FP32;
-    registerPrecisions[AttentionOperand::dS] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
+    registerPrecisions[AttentionOperand::dS] = GEMMOperandPrecision::FP32;
   } else {
     registerPrecisions[AttentionOperand::S] = GEMMOperandPrecision::FP32;
     registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP32;
@@ -389,7 +392,11 @@ AttentionOperands AttentionDescriptor::createRegisterPreci
   }
 
   // All of the outputs are accumulated in FP32.
-  registerPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP32;
+  if (type.value != AttentionKernelType::forward && lowPrecisionInputs) {
+    registerPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP16;
+  } else {
+    registerPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP32;
+  }
   registerPrecisions[AttentionOperand::dV] = GEMMOperandPrecision::FP32;
   registerPrecisions[AttentionOperand::dK] = GEMMOperandPrecision::FP32;
   registerPrecisions[AttentionOperand::dQ] = GEMMOperandPrecision::FP32;
diff --git a/test/int/nnc/mpsblas.tests.c b/test/int/nnc/mpsblas.tests.c
index 1539f6b83..a3ab3fcd2 100644
--- a/test/int/nnc/mpsblas.tests.c
+++ b/test/int/nnc/mpsblas.tests.c
@@ -1780,6 +1780,123 @@ TEST_CASE("scaled dot product attention gradient with mps")
 #undef num_trials
 }
 
+TEST_CASE("scaled dot product attention gradient with mps in half-precision")
+{
+	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_FORWARD, CCV_NNC_BACKEND_MPS) &&
+		ccv_nnc_cmd_ok(CCV_NNC_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD, CCV_NNC_BACKEND_MPS));
+#define num_long_trials 8
+#define num_short_trials 4
+#define num_trials (num_long_trials + num_short_trials)
+
+	dsfmt_t dsfmt;
+	dsfmt_init_gen_rand(&dsfmt, 10);
+	for (int trial = 0; trial < num_trials; ++trial) {
+		const int B_candidates[num_trials] = { 32, 12, 16, 1, 2, 1, 32, 12, 16, 1, 2, 1 };
+		const int R_candidates[num_trials] = { 160, 256, 128, 77, 77, 5, 160, 256, 128, 77, 77, 5 };
+		const int C_candidates[num_trials] = { 128, 128, 128, 128, 128, 5, 128, 128, 128, 128, 128, 5 };
+		const int Hq_candidates[num_trials] = { 8, 8, 8, 8, 8, 32, 8, 8, 8, 8, 8, 32 };
+		const int D_candidates[num_trials] = { 64, 40, 160, 192, 256, 128, 64, 40, 160, 192, 256, 128 };
+
+		const int B = B_candidates[trial];
+		const int R = R_candidates[trial];
+		const int C = C_candidates[trial];
+		const int Hq = Hq_candidates[trial];
+		const int Hk = Hq_candidates[trial];
+		const int D = D_candidates[trial];
+		const int is_causal = 0;
+		const float scale = 1.0 / sqrt((float)D);
+
+		ccv_nnc_tensor_t* const q_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const k_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const v_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const dq_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const dk_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const dv_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
+
+		for (int i = 0; i < B * R * Hq * D; ++i) {
+			q_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
+		}
+		for (int i = 0; i < B * C * Hk * D; ++i) {
+			k_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
+		}
+		for (int i = 0; i < B * C * Hk * D; ++i) {
+			v_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
+		}
+
+		ccv_nnc_tensor_t* const do_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
+		for (int i = 0; i < B * R * Hq * D; ++i) {
+			do_tensor->data.f32[i] = dsfmt_genrand_open_close(&dsfmt);
+		}
+		ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD(scale, is_causal), ccv_nnc_no_hint, 0, TENSOR_LIST(do_tensor, 0, 0, q_tensor, k_tensor, v_tensor), TENSOR_LIST(dq_tensor, dk_tensor, dv_tensor), 0);
+		ccv_nnc_tensor_t* const q_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const k_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const v_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const do_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, Hq, D), 0);
+		ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor, k_tensor, v_tensor, do_tensor), TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16, do_tensor_f16), 0);
+
+		ccv_nnc_tensor_t* const gpu_q_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const gpu_k_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const gpu_v_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const gpu_o_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const gpu_do_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const gpu_dq_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const gpu_dk_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const gpu_dv_tensor = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 16F, B, C, Hk, D), 0);
+		ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(q_tensor_f16, k_tensor_f16, v_tensor_f16, do_tensor_f16), TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor, gpu_do_tensor), 0);
+
+		ccv_nnc_tensor_t* const gpu_softmax_lse = ccv_nnc_tensor_new(0, GPU_TENSOR_NHWC(000, 32F, B, Hq, R), 0);
+		ccv_nnc_cmd_exec(CMD_SCALED_DOT_PRODUCT_ATTENTION_FORWARD(scale, is_causal), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_q_tensor, gpu_k_tensor, gpu_v_tensor, NULL, NULL, NULL), TENSOR_LIST(gpu_o_tensor, gpu_softmax_lse), 0);
+
+		ccv_nnc_cmd_t cmd = CMD_SCALED_DOT_PRODUCT_ATTENTION_BACKWARD(scale, is_causal);
+		cmd.info.scaled_dot_product_attention.deterministic = 0;
+		ccv_nnc_cmd_exec(cmd, ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_do_tensor, 0, 0, gpu_q_tensor, gpu_k_tensor, gpu_v_tensor, 0, 0, 0, gpu_o_tensor, gpu_softmax_lse), TENSOR_LIST(gpu_dq_tensor, gpu_dk_tensor, gpu_dv_tensor), 0);
+
+		ccv_nnc_tensor_t* const copy_of_gpu_dq_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const copy_of_gpu_dk_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const copy_of_gpu_dv_tensor_f16 = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(16F, B, C, Hk, D), 0);
+		ccv_nnc_cmd_exec(CMD_DATA_TRANSFER_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(gpu_dq_tensor, gpu_dk_tensor, gpu_dv_tensor), TENSOR_LIST(copy_of_gpu_dq_tensor_f16, copy_of_gpu_dk_tensor_f16, copy_of_gpu_dv_tensor_f16), 0);
+
+		ccv_nnc_tensor_t* const copy_of_gpu_dq_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, R, Hq, D), 0);
+		ccv_nnc_tensor_t* const copy_of_gpu_dk_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
+		ccv_nnc_tensor_t* const copy_of_gpu_dv_tensor = ccv_nnc_tensor_new(0, CPU_TENSOR_NHWC(32F, B, C, Hk, D), 0);
+		ccv_nnc_cmd_exec(CMD_DATATYPE_CONVERSION_FORWARD(), ccv_nnc_no_hint, 0, TENSOR_LIST(copy_of_gpu_dq_tensor_f16, copy_of_gpu_dk_tensor_f16, copy_of_gpu_dv_tensor_f16), TENSOR_LIST(copy_of_gpu_dq_tensor, copy_of_gpu_dk_tensor, copy_of_gpu_dv_tensor), 0);
+
+		REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, copy_of_gpu_dq_tensor->data.f32, dq_tensor->data.f32, B * R * Hq * D, 1e-3, "scaled dot product attention result should be the same");
+		REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, copy_of_gpu_dk_tensor->data.f32, dk_tensor->data.f32, B * C * Hk * D, 3e-3, "scaled dot product attention result should be the same");
+		REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, copy_of_gpu_dv_tensor->data.f32, dv_tensor->data.f32, B * C * Hk * D, 6e-3, "GPU computed output should be the same as CPU computed ones");
+
+		ccv_nnc_tensor_free(do_tensor);
+		ccv_nnc_tensor_free(gpu_do_tensor);
+		ccv_nnc_tensor_free(gpu_o_tensor);
+		ccv_nnc_tensor_free(copy_of_gpu_dq_tensor_f16);
+		ccv_nnc_tensor_free(copy_of_gpu_dk_tensor_f16);
+		ccv_nnc_tensor_free(copy_of_gpu_dv_tensor_f16);
+		ccv_nnc_tensor_free(copy_of_gpu_dq_tensor);
+		ccv_nnc_tensor_free(copy_of_gpu_dk_tensor);
+		ccv_nnc_tensor_free(copy_of_gpu_dv_tensor);
+		ccv_nnc_tensor_free(q_tensor);
+		ccv_nnc_tensor_free(k_tensor);
+		ccv_nnc_tensor_free(v_tensor);
+		ccv_nnc_tensor_free(q_tensor_f16);
+		ccv_nnc_tensor_free(k_tensor_f16);
+		ccv_nnc_tensor_free(v_tensor_f16);
+		ccv_nnc_tensor_free(do_tensor_f16);
+		ccv_nnc_tensor_free(gpu_q_tensor);
+		ccv_nnc_tensor_free(gpu_k_tensor);
+		ccv_nnc_tensor_free(gpu_v_tensor);
+		ccv_nnc_tensor_free(dq_tensor);
+		ccv_nnc_tensor_free(dk_tensor);
+		ccv_nnc_tensor_free(dv_tensor);
+		ccv_nnc_tensor_free(gpu_dq_tensor);
+		ccv_nnc_tensor_free(gpu_dk_tensor);
+		ccv_nnc_tensor_free(gpu_dv_tensor);
+		ccv_nnc_tensor_free(gpu_softmax_lse);
+	}
+#undef num_long_trials
+#undef num_short_trials
+#undef num_trials
+}
+
 TEST_CASE("backward gemm with no transpose")
 {
 	GUARD_ELSE_RETURN(ccv_nnc_cmd_ok(CCV_NNC_GEMM_FORWARD, CCV_NNC_BACKEND_MPS) &&