Code as it should be, but doesn't work
sushraja-msft committed Nov 22, 2024
1 parent c281f84 commit befc797
Showing 1 changed file with 46 additions and 63 deletions.
109 changes: 46 additions & 63 deletions onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
@@ -569,44 +569,36 @@ fn debugKTile() -> q_value_t
return sum_value;
}
fn loadk(slot: u32, k_idx_global : u32, head_idx: u32)
fn loadk(slot: u32, k_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32)
{
if (k_idx_global >= uniforms.present_sequence_length) {
return;
}
// Stored as float16[batch_size,num_heads,present_sequence_length,96]
let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + k_idx_global * QKV_HEAD_VECTORIZED_SIZE;
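// With the subgroup-strided loop below, each lane copies elements idx = sg_id, sg_id + sg_size, ...
// so the lanes of one wave cooperatively fill k_tile[slot] for this key row.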
for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx++)
for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size)
{
var value = present_key[idx+offset];
k_tile[slot][idx] = value;
}
}
fn loadv(slot: u32, v_idx_global : u32, head_idx: u32)
fn loadv(slot: u32, v_idx_global : u32, head_idx: u32, sg_id: u32, sg_size: u32)
{
if (v_idx_global >= uniforms.present_sequence_length) {
return;
}
// Stored as float16[batch_size,num_heads,present_sequence_length,96]
let offset = head_idx * uniforms.present_sequence_length * QKV_HEAD_VECTORIZED_SIZE + v_idx_global * QKV_HEAD_VECTORIZED_SIZE;
for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx ++)
for (var idx:u32 = sg_id; idx < QKV_HEAD_VECTORIZED_SIZE; idx+=sg_size)
{
v_tile[slot][idx] = present_value[idx+offset];
}
}
fn loadAttentionBias(qtile_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32)
fn loadAttentionBias(q_row: u32, q_idx_global : u32, k_col: u32, k_idx_global : u32, head_idx: u32)
{
// Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length]
if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.present_sequence_length) {
qk_tile[qtile_row][k_col] = 0.0;
qk_tile[q_row][k_col] = 0.0;
return;
}
let offset = head_idx * uniforms.new_sequence_length * uniforms.present_sequence_length + q_idx_global * uniforms.present_sequence_length + k_idx_global;
qk_tile[qtile_row][k_col] = attention_bias[offset];
qk_tile[q_row][k_col] = attention_bias[offset];
}
fn writeo(slot: u32, o_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u32)
@@ -705,37 +697,54 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool)
)HELPER_FN";

// Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / TILE_SIZE, 1)
// QKV_HEAD_VECTORIZED_SIZE % sg_size == 0 is required for loadq, loadk and computeDotProduct to work right.
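// Illustrative example (numbers are hypothetical, not part of this change): with num_heads = 32,
// a new sequence length of 128 and TILE_SIZE = 16, this is dispatched as Dispatch(32, 8, 1),
// i.e. one workgroup per (head, tile of 16 query rows).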

shader.MainFunctionBody() << R"MAIN_FN(
let head_idx = workgroup_id.x;
// Split the composite workgroup id into actual y and subgroup id.
let q_tile_row = u32(local_idx / sg_size);
let q_idx_global = workgroup_id.y * TILE_SIZE + q_tile_row;
// Each invocation (q_tile_row) gets x threads (subgroup threads) and is responsible for 1 query.
loadq(q_tile_row, q_idx_global, head_idx, sg_id, sg_size);
max_tile[sg_id] = MIN_VALUE;
let wave_x = (local_id.x / 4);
let wave_y = (local_id.y / 4);
let wave_id = wave_x + wave_y * 4;
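// With a subgroup_size x subgroup_size (16x16) workgroup, each 4x4 block of invocations forms one
// wave of 16 lanes; wave_id ranges over [0, TILE_SIZE) and selects the query row this wave owns.
// This grouping appears to assume the hardware subgroup size is 16.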
let q_idx_start = workgroup_id.y * TILE_SIZE;
let q_idx_global = q_idx_start + wave_id;
// Each wave (wave_id) gets sg_size lanes (one subgroup) and is responsible for 1 query.
loadq(wave_id, q_idx_global, head_idx, sg_id, sg_size);
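// max_tile[wave_id] presumably holds this wave's running row max for the streaming softmax below;
// one lane per wave is enough to initialize the shared value.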
if (sg_id == 0)
{
max_tile[wave_id] = MIN_VALUE;
}
for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE)
{
if (sg_id < TILE_SIZE && k_start+sg_id < uniforms.present_sequence_length) {
loadk(sg_id, k_start+sg_id, head_idx);
loadv(sg_id, k_start+sg_id, head_idx);
loadAttentionBias(q_tile_row, q_idx_global, sg_id, k_start+sg_id, head_idx);
let k_idx_global = k_start+wave_id;
let k_idx_global_using_wave_valid = k_idx_global < uniforms.present_sequence_length;
if (k_idx_global_using_wave_valid) {
// Leveraging the subgroup lanes for parallelism, load the K/V values
// from row k_start+wave_id into slot wave_id.
loadk(wave_id, k_idx_global, head_idx, sg_id, sg_size);
loadv(wave_id, k_idx_global, head_idx, sg_id, sg_size);
// Next, for every q row (wave_id), populate the attention bias for column (k_start+sg_id).
// loadAttentionBias range checks both q_idx_global and (k_start+sg_id).
loadAttentionBias(wave_id, q_idx_global, sg_id, k_start+sg_id, head_idx);
}
workgroupBarrier();
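// After this barrier, the workgroup's TILE_SIZE waves have staged one TILE_SIZE-row tile of K and V
// in k_tile/v_tile, along with the matching attention-bias values in qk_tile.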
// Use k_idx + k_start <= q_idx_global here if we only want to look at past positions (causal).
for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start < uniforms.present_sequence_length; k_idx++)
if (k_idx_global_using_wave_valid)
{
computeDotProduct(q_tile_row, k_idx, sg_id, sg_size);
for (var q_idx = 0u; q_idx < TILE_SIZE && q_idx_start + q_idx < uniforms.present_sequence_length; q_idx++)
{
// Leveraging the subgroups for parallelism, compute the dot product of Q and K.
// Because in the new_seq == 1 case there is a single query against the full context length of K,
// we iterate over q and use the waves for K so that this step can use all the waves
// in the workgroup.
computeDotProduct(q_idx, wave_id, sg_id, sg_size);
}
}
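// computeSoftMax/computeO presumably fold this K/V tile into the wave's running softmax statistics
// and output accumulator (flash-attention style), so only a TILE_SIZE-wide slice of scores is live
// at a time.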
let enabled:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length;
computeSoftMax(q_tile_row, sg_id, enabled);
computeO(q_tile_row, sg_id, enabled);
let k_idx_global_using_lane_valid:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length;
computeSoftMax(wave_id, sg_id, k_idx_global_using_lane_valid);
computeO(wave_id, sg_id, k_idx_global_using_lane_valid);
}
workgroupBarrier();
writeo(q_tile_row, q_idx_global, head_idx, sg_id, sg_size);
writeo(wave_id, q_idx_global, head_idx, sg_id, sg_size);
)MAIN_FN";

@@ -747,33 +756,8 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
ORT_RETURN_IF_ERROR(CopyKVCache(context, parameters, K, past_key, present_key, V, past_value, present_value, parameters.past_sequence_length, parameters.total_sequence_length));

// // Uncomment to test CopyKVCache independent of FlashAttentionProgram.
// TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads,
// parameters.sequence_length, parameters.head_size});
// TensorShape q_new_shape(q_new_dims);
// Tensor Qn = context.CreateGPUTensor(Q->DataType(), q_new_shape);
// ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
// context, parameters.num_heads, parameters.sequence_length, parameters.head_size, Q, nullptr, 0, &Qn));

// TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads,
// parameters.kv_sequence_length, parameters.head_size});
// TensorShape k_new_shape(k_new_dims);
// Tensor Kn = context.CreateGPUTensor(K->DataType(), k_new_shape);
// ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
// parameters.head_size, K, nullptr, parameters.hidden_size, &Kn));

// TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads,
// parameters.kv_sequence_length, parameters.v_head_size});
// TensorShape v_new_shape(v_new_dims);
// Tensor Vn = context.CreateGPUTensor(V->DataType(), v_new_shape);
// ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
// parameters.v_head_size, V, nullptr, 2 * parameters.hidden_size, &Vn));

// return ApplyAttention(&Qn, &Kn, &Vn, attention_bias, past_key, past_value, output, present_key,
// present_value, parameters, context, true);

constexpr int subgroup_size = 8;
constexpr int tile_size = 8;
constexpr int subgroup_size = 16;
constexpr int tile_size = 16;
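// The workgroup is subgroup_size x subgroup_size (16 x 16 = 256 invocations); the shader groups these
// into TILE_SIZE waves of 16 lanes via the 4x4 wave_id decomposition, which appears to assume the
// hardware subgroup size is also 16.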
bool has_attention_bias = attention_bias != nullptr;
FlashAttentionProgram program{"FlashAttention", has_attention_bias, subgroup_size, tile_size, parameters.head_size, parameters.num_heads};
program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4},
@@ -789,7 +773,7 @@
std::to_string(parameters.head_size) +
std::to_string(parameters.num_heads);
program.SetDispatchGroupSize(parameters.num_heads, (parameters.sequence_length + tile_size - 1) / tile_size, 1)
.SetWorkgroupSize(subgroup_size*tile_size)
.SetWorkgroupSize(subgroup_size, subgroup_size)
.CacheHint(cache_hint)
.AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)},
{static_cast<uint32_t>(parameters.total_sequence_length)},
@@ -854,7 +838,6 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&

if (parameters.batch_size == 1 &&
bias == nullptr &&
past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0 && past_value->SizeInBytes() > 0 &&
present_key != nullptr && present_value != nullptr && present_key->SizeInBytes() > 0 &&
present_value->SizeInBytes() > 0 && parameters.head_size % 4 == 0) {
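// Flash attention path: batch size 1, no bias input, present K/V buffers available, and head_size a
// multiple of 4 (the Q/K/V inputs are bound with 4 components per element).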
return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value,