Skip to content

Commit

Permalink
Variable renames
Browse files Browse the repository at this point in the history
  • Loading branch information
sushraja-msft committed Nov 21, 2024
1 parent c281f84 commit de32d1f
Showing 1 changed file with 13 additions and 26 deletions.
39 changes: 13 additions & 26 deletions onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,20 +555,6 @@ fn loadq(slot: u32, q_idx_global : u32, head_idx: u32, sg_id : u32, sg_size : u3
}
}
fn debugKTile() -> q_value_t
{
var sum_value = q_value_t(0);
for (var qidx:u32 = 0; qidx < TILE_SIZE; qidx++)
{
for (var idx:u32 = 0; idx < QKV_HEAD_VECTORIZED_SIZE; idx++)
{
var value = k_tile[qidx][idx];
sum_value += value;
}
}
return sum_value;
}
fn loadk(slot: u32, k_idx_global : u32, head_idx: u32)
{
if (k_idx_global >= uniforms.present_sequence_length) {
Expand Down Expand Up @@ -709,33 +695,34 @@ fn computeO(q_idx: u32, sg_id:u32, enabled:bool)

shader.MainFunctionBody() << R"MAIN_FN(
let head_idx = workgroup_id.x;
// Split the composite workgroup id into actual y and subgroup id.
let q_tile_row = u32(local_idx / sg_size);
// Split the composite workgroup id into subgroup_cluster_id and subgroup_id.
let subgroup_cluster_id = u32(local_idx / sg_size);
let q_idx_global = workgroup_id.y * TILE_SIZE + q_tile_row;
// Each invocation (q_tile_row) gets x threads (subgroup threads) and is responsible for 1 query.
loadq(q_tile_row, q_idx_global, head_idx, sg_id, sg_size);
let q_idx_start = workgroup_id.y * TILE_SIZE;
let q_idx_global = q_idx_start + subgroup_cluster_id;
// Each invocation (subgroup_cluster_id) gets x threads (subgroup threads) and is responsible for 1 query.
loadq(subgroup_cluster_id, q_idx_global, head_idx, sg_id, sg_size);
max_tile[sg_id] = MIN_VALUE;
for(var k_start = 0u; k_start < uniforms.present_sequence_length; k_start+=TILE_SIZE)
{
if (sg_id < TILE_SIZE && k_start+sg_id < uniforms.present_sequence_length) {
let enabled:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length;
if (enabled) {
loadk(sg_id, k_start+sg_id, head_idx);
loadv(sg_id, k_start+sg_id, head_idx);
loadAttentionBias(q_tile_row, q_idx_global, sg_id, k_start+sg_id, head_idx);
loadAttentionBias(subgroup_cluster_id, q_idx_global, sg_id, k_start+sg_id, head_idx);
}
workgroupBarrier();
// Do k_idx + k_start <= q_idx_global if we want only look past.
for (var k_idx = 0u; k_idx < TILE_SIZE && k_idx + k_start < uniforms.present_sequence_length; k_idx++)
{
computeDotProduct(q_tile_row, k_idx, sg_id, sg_size);
computeDotProduct(subgroup_cluster_id, k_idx, sg_id, sg_size);
}
let enabled:bool = sg_id < TILE_SIZE && sg_id + k_start < uniforms.present_sequence_length;
computeSoftMax(q_tile_row, sg_id, enabled);
computeO(q_tile_row, sg_id, enabled);
computeSoftMax(subgroup_cluster_id, sg_id, enabled);
computeO(subgroup_cluster_id, sg_id, enabled);
}
workgroupBarrier();
writeo(q_tile_row, q_idx_global, head_idx, sg_id, sg_size);
writeo(subgroup_cluster_id, q_idx_global, head_idx, sg_id, sg_size);
)MAIN_FN";

Expand Down

0 comments on commit de32d1f

Please sign in to comment.