Fix various issues related to running backprop of SDPA on FP16.
liuliu committed Oct 30, 2024
1 parent 18f1860 commit fe99c48
Showing 3 changed files with 160 additions and 33 deletions.
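For context on the first file: the SDPA backward pass allocates an FP32 scratch buffer, and when the inputs are FP16 that buffer holds dQ, dK and dV (accumulated in FP32, later cast back down to FP16) in addition to the per-row D term. The old code sized all three gradient slices by hash.R, which under-allocates dK and dV whenever the key/value sequence length hash.C exceeds the query length hash.R; the new code sizes them by hash.C and adjusts every offset to match. A minimal sketch of the layout implied by the offsets in the diff, assuming hash.R and hash.C are the query and key/value sequence lengths, hash.D the head dimension and hash.Hq the number of query heads (the helper itself is hypothetical, not repository code):

#include <cstddef>
#include <cstdint>

// Byte offsets into the FP32 scratch buffer used by the SDPA backward pass.
// All names are illustrative; only the arithmetic mirrors the diff.
struct ScratchLayout {
  size_t dQ;         // R x D x Hq x batch floats (FP16-input path only)
  size_t dK;         // C x D x Hq x batch floats (FP16-input path only)
  size_t dV;         // C x D x Hq x batch floats (FP16-input path only)
  size_t D;          // R x Hq x batch floats (row-wise dot(dO, O) term)
  size_t totalBytes;
};

static ScratchLayout makeScratchLayout(int64_t R, int64_t C, int64_t headDim,
                                       int64_t Hq, int64_t batch,
                                       bool lowPrecisionInputs) {
  ScratchLayout s = {};
  size_t offset = 0;
  if (lowPrecisionInputs) {
    // Gradients are accumulated in FP32 here, then cast back to FP16 tensors.
    s.dQ = offset;  offset += sizeof(float) * R * headDim * Hq * batch;
    s.dK = offset;  offset += sizeof(float) * C * headDim * Hq * batch;
    s.dV = offset;  offset += sizeof(float) * C * headDim * Hq * batch;
  }
  // D lands at sizeof(float) * (R + 2 * C) * headDim * Hq * batch on the FP16
  // path, or at 0 otherwise -- exactly the offsets used in the encoders below.
  s.D = offset;
  offset += sizeof(float) * R * Hq * batch;
  s.totalBytes = offset;
  return s;
}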
55 changes: 29 additions & 26 deletions lib/nnc/mfa/ccv_nnc_mfa_attention.cpp
@@ -128,7 +128,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
if (tensors[5]) {
encoder->useResource(tensors[5], MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
}
-encoder->setBuffer(tensors[0], tensor_offsets[0], AttentionOperand(AttentionOperand::Q).bufferIndex());
+encoder->setBuffer(tensors[0], tensor_offsets[0], AttentionOperand(AttentionOperand::Q).bufferIndex());
encoder->setBuffer(tensors[1], tensor_offsets[1], AttentionOperand(AttentionOperand::K).bufferIndex());
encoder->setBuffer(tensors[2], tensor_offsets[2], AttentionOperand(AttentionOperand::V).bufferIndex());
if (attentionDesc.lowPrecisionInputs) {
@@ -198,6 +198,15 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
auto backwardKeyValueKernel = backwardKeyValuePipelineValue->kernel;
auto backwardKeyValuePipeline = backwardKeyValuePipelineValue->pipeline;

+auto scratch_size = 0;
+if (attentionDesc.lowPrecisionInputs) {
+// Need scratch space for FP16 output.
+scratch_size += sizeof(float) * (hash.R + hash.C * 2) * hash.D * hash.Hq * attentionDesc.batchDimension;
+}
+// Need scratch space for D.
+scratch_size += sizeof(float) * hash.R * hash.Hq * attentionDesc.batchDimension;
+auto scratch = context->request_scratch(scratch_size);
+
// Allocate a new command.
auto backwardQueryEncoder = command_batch->startCommand();
backwardQueryEncoder->setComputePipelineState(backwardQueryPipeline.get());
@@ -210,14 +219,6 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
backwardQueryEncoder->useResource(tensors[3], MTL::ResourceUsageRead);
backwardQueryEncoder->useResource(tensors[4], MTL::ResourceUsageRead);
backwardQueryEncoder->useResource(tensors[5], MTL::ResourceUsageRead);
-auto scratch_size = 0;
-if (attentionDesc.lowPrecisionInputs) {
-// Need scratch space for FP16 output.
-scratch_size += sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension * 3;
-}
-// Need scratch space for D.
-scratch_size += sizeof(float) * hash.R * hash.Hq * attentionDesc.batchDimension;
-auto scratch = context->request_scratch(scratch_size);
backwardQueryEncoder->useResource(scratch, MTL::ResourceUsageRead | MTL::ResourceUsageWrite);
if (!attentionDesc.lowPrecisionInputs) {
backwardQueryEncoder->useResource(tensors[6], MTL::ResourceUsageWrite);
@@ -231,7 +232,7 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
backwardQueryEncoder->setBuffer(tensors[5], tensor_offsets[5], AttentionOperand(AttentionOperand::dO).bufferIndex());
if (attentionDesc.lowPrecisionInputs) {
backwardQueryEncoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::dQ).bufferIndex());
-backwardQueryEncoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension * 3, AttentionOperand(AttentionOperand::D).bufferIndex());
+backwardQueryEncoder->setBuffer(scratch, sizeof(float) * (hash.R + hash.C * 2) * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::D).bufferIndex());
} else {
backwardQueryEncoder->setBuffer(tensors[6], tensor_offsets[6], AttentionOperand(AttentionOperand::dQ).bufferIndex());
backwardQueryEncoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::D).bufferIndex());
@@ -268,24 +269,24 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
backwardKeyValueEncoder->useResource(tensors[8], MTL::ResourceUsageWrite);
}

-backwardQueryEncoder->setBuffer(tensors[0], tensor_offsets[0], AttentionOperand(AttentionOperand::Q).bufferIndex());
-backwardQueryEncoder->setBuffer(tensors[1], tensor_offsets[1], AttentionOperand(AttentionOperand::K).bufferIndex());
-backwardQueryEncoder->setBuffer(tensors[2], tensor_offsets[2], AttentionOperand(AttentionOperand::V).bufferIndex());
-backwardQueryEncoder->setBuffer(tensors[3], tensor_offsets[3], AttentionOperand(AttentionOperand::O).bufferIndex());
-backwardQueryEncoder->setBuffer(tensors[4], tensor_offsets[4], AttentionOperand(AttentionOperand::L).bufferIndex());
-backwardQueryEncoder->setBuffer(tensors[5], tensor_offsets[5], AttentionOperand(AttentionOperand::dO).bufferIndex());
+backwardKeyValueEncoder->setBuffer(tensors[0], tensor_offsets[0], AttentionOperand(AttentionOperand::Q).bufferIndex());
+backwardKeyValueEncoder->setBuffer(tensors[1], tensor_offsets[1], AttentionOperand(AttentionOperand::K).bufferIndex());
+backwardKeyValueEncoder->setBuffer(tensors[2], tensor_offsets[2], AttentionOperand(AttentionOperand::V).bufferIndex());
+backwardKeyValueEncoder->setBuffer(tensors[3], tensor_offsets[3], AttentionOperand(AttentionOperand::O).bufferIndex());
+backwardKeyValueEncoder->setBuffer(tensors[4], tensor_offsets[4], AttentionOperand(AttentionOperand::L).bufferIndex());
+backwardKeyValueEncoder->setBuffer(tensors[5], tensor_offsets[5], AttentionOperand(AttentionOperand::dO).bufferIndex());
if (attentionDesc.lowPrecisionInputs) {
-backwardQueryEncoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::dK).bufferIndex());
-backwardQueryEncoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension * 2, AttentionOperand(AttentionOperand::dV).bufferIndex());
-backwardQueryEncoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension * 3, AttentionOperand(AttentionOperand::D).bufferIndex());
+backwardKeyValueEncoder->setBuffer(scratch, sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::dK).bufferIndex());
+backwardKeyValueEncoder->setBuffer(scratch, sizeof(float) * (hash.R + hash.C) * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::dV).bufferIndex());
+backwardKeyValueEncoder->setBuffer(scratch, sizeof(float) * (hash.R + hash.C * 2) * hash.D * hash.Hq * attentionDesc.batchDimension, AttentionOperand(AttentionOperand::D).bufferIndex());
} else {
-backwardQueryEncoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::D).bufferIndex());
-backwardQueryEncoder->setBuffer(tensors[7], tensor_offsets[7], AttentionOperand(AttentionOperand::dK).bufferIndex());
-backwardQueryEncoder->setBuffer(tensors[8], tensor_offsets[8], AttentionOperand(AttentionOperand::dV).bufferIndex());
+backwardKeyValueEncoder->setBuffer(scratch, 0, AttentionOperand(AttentionOperand::D).bufferIndex());
+backwardKeyValueEncoder->setBuffer(tensors[7], tensor_offsets[7], AttentionOperand(AttentionOperand::dK).bufferIndex());
+backwardKeyValueEncoder->setBuffer(tensors[8], tensor_offsets[8], AttentionOperand(AttentionOperand::dV).bufferIndex());
}

MTL::Size backwardKeyValueGridSize
-(ceilDivide(int64_t(hash.R), backwardKeyValueKernel->blockDimensions[0]),
+(ceilDivide(int64_t(hash.C), backwardKeyValueKernel->blockDimensions[0]),
hash.Hq,
attentionDesc.batchDimension);
MTL::Size backwardKeyValueGroupSize
@@ -315,12 +316,14 @@ void ccv_nnc_mfa_encode_attention(mfa::context* context, ccv_nnc_mfa_attention_p
tensor_offsets[6]
};
ccv_nnc_mfa_encode_cast(context, cast_params, command_batch, cast_tensors, cast_tensor_offsets);
-cast_tensors[1] = tensors[7];
+cast_params.length = hash.C * hash.D * hash.Hq * attentionDesc.batchDimension;
+ccv_nnc_mfa_prepare_cast(context, cast_params);
+cast_tensors[1] = tensors[7];
cast_tensor_offsets[0] = sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension;
cast_tensor_offsets[1] = tensor_offsets[7];
ccv_nnc_mfa_encode_cast(context, cast_params, command_batch, cast_tensors, cast_tensor_offsets);
-cast_tensors[1] = tensors[8];
-cast_tensor_offsets[0] = sizeof(float) * hash.R * hash.D * hash.Hq * attentionDesc.batchDimension * 2;
+cast_tensors[1] = tensors[8];
+cast_tensor_offsets[0] = sizeof(float) * (hash.R + hash.C) * hash.D * hash.Hq * attentionDesc.batchDimension;
cast_tensor_offsets[1] = tensor_offsets[8];
ccv_nnc_mfa_encode_cast(context, cast_params, command_batch, cast_tensors, cast_tensor_offsets);
}
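Two of the fixes above are easy to miss. The backward key/value kernel is now dispatched over ceilDivide(hash.C, blockDimensions[0]) blocks instead of hash.R, since dK and dV have one row per key/value position. And on the FP16 path, the FP32 gradients accumulated in scratch are cast back down to the FP16 output tensors: the dK/dV casts now use length hash.C * hash.D * hash.Hq * batch and read from the corrected scratch offsets. A rough CPU-side picture of one such cast-back (the real path runs the ccv_nnc_mfa_encode_cast kernel shown above; the helper and the __fp16 extension type are only illustrative):

#include <cstddef>

// Illustrative only: convert `length` FP32 values, starting `floatOffset`
// elements into the scratch buffer, into a half-precision destination.
static void castScratchToHalf(const float *scratch, size_t floatOffset,
                              __fp16 *dst, size_t length) {
  for (size_t i = 0; i < length; ++i)
    dst[i] = (__fp16)scratch[floatOffset + i];
}

// dQ: length = R * D * Hq * batch, scratch element offset = 0
// dK: length = C * D * Hq * batch, scratch element offset = R * D * Hq * batch
// dV: length = C * D * Hq * batch, scratch element offset = (R + C) * D * Hq * batch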
21 changes: 14 additions & 7 deletions lib/nnc/mfa/v2/AttentionDescriptor.cpp
@@ -199,7 +199,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createMemoryPrecisi
memoryPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP16;
memoryPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP16;
memoryPrecisions[AttentionOperand::V] = GEMMOperandPrecision::FP16;
-memoryPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP32; // GEMMOperandPrecision::BF16;
+memoryPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP16;
} else {
memoryPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP32;
memoryPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP32;
@@ -321,8 +321,11 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createMemoryPrecisi
// will always write O as FP32 in memory. This choice simplifies
// everything, just like the choice to always store log-sum-exp during the
// forward pass. It also removes the concern of rounding error from
// frequently truncating the FP32 numbers to FP16.
-memoryPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP32;
+if (type.value != AttentionKernelType::forward && lowPrecisionInputs) {
+memoryPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP16;
+} else {
+memoryPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP32;
+}
memoryPrecisions[AttentionOperand::dV] = GEMMOperandPrecision::FP32;
memoryPrecisions[AttentionOperand::dK] = GEMMOperandPrecision::FP32;
memoryPrecisions[AttentionOperand::dQ] = GEMMOperandPrecision::FP32;
@@ -342,7 +345,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
registerPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP16;
registerPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP16;
registerPrecisions[AttentionOperand::V] = GEMMOperandPrecision::FP16;
-registerPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
+registerPrecisions[AttentionOperand::dO] = GEMMOperandPrecision::FP16;
} else {
registerPrecisions[AttentionOperand::Q] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::K] = GEMMOperandPrecision::FP32;
@@ -353,7 +356,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
// The register precision of L/D only counts for backward key-value.
if (lowPrecisionIntermediates) {
registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP16;
-registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
+registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
} else {
registerPrecisions[AttentionOperand::L] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::D] = GEMMOperandPrecision::FP32;
@@ -380,7 +383,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
registerPrecisions[AttentionOperand::S] = lowPrecisionInputs ? GEMMOperandPrecision::FP16 : GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP16;
registerPrecisions[AttentionOperand::dP] = GEMMOperandPrecision::FP32;
-registerPrecisions[AttentionOperand::dS] = GEMMOperandPrecision::FP32; // hasNativeBF16Casting ? GEMMOperandPrecision::BF16 : GEMMOperandPrecision::FP32;
+registerPrecisions[AttentionOperand::dS] = GEMMOperandPrecision::FP32;
} else {
registerPrecisions[AttentionOperand::S] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::P] = GEMMOperandPrecision::FP32;
@@ -389,7 +392,7 @@ AttentionOperands<GEMMOperandPrecision> AttentionDescriptor::createRegisterPreci
}

// All of the outputs are accumulated in FP32.
-registerPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP32;
+if (type.value != AttentionKernelType::forward && lowPrecisionInputs) {
+registerPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP16;
+} else {
+registerPrecisions[AttentionOperand::O] = GEMMOperandPrecision::FP32;
+}
registerPrecisions[AttentionOperand::dV] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::dK] = GEMMOperandPrecision::FP32;
registerPrecisions[AttentionOperand::dQ] = GEMMOperandPrecision::FP32;
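A plausible reading of the precision changes in this file: with FP16 inputs, the backward kernels now declare O and dO as FP16 in both memory and registers, presumably because the O tensor handed to the backward pass was produced as FP16 by the forward pass and dO arrives as FP16, while the forward kernel still writes O as FP32 and every gradient accumulator (dQ, dK, dV, plus the L and D terms) stays FP32. A compressed sketch of the O selection; the enum definitions here are simplified stand-ins, not the repository's types:

enum class AttentionKernelType { forward, backwardQuery, backwardKeyValue };
enum class GEMMOperandPrecision { FP16, BF16, FP32 };

// Sketch of the O precision choice after this commit.
static GEMMOperandPrecision oPrecision(AttentionKernelType type,
                                       bool lowPrecisionInputs) {
  if (type != AttentionKernelType::forward && lowPrecisionInputs)
    return GEMMOperandPrecision::FP16;  // backward reads the FP16 O written by forward
  return GEMMOperandPrecision::FP32;    // forward accumulates and stores O in FP32
}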