Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] support FlashAttention-2 for attention operator #22915

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 179 additions & 12 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1004,21 +1004,188 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
);
};

// The algorithm refers the FlashAttention forward pass in the https://tridao.me/publications/flash2/flash2.pdf
const createFlashAttentionV2ProgramInfo = (
q: TensorView,
k: TensorView,
v: TensorView,
parameters: AttentionParameters,
attributes: AttentionAttrs,
) => {
const components = 4;
const bR = 32;
const bC = bR;
const tR = q.dims[2] / bR;
const tC = k.dims[2] / bC;
const d = q.dims[3] / components;
const numTiles = Math.ceil(q.dims[3] / bC);
const workgroupSize: [number, number, number] = [8, 32, 1];
const qInner = numTiles * workgroupSize[0];
const colsPerThread = 4; // (Bc / workgroupSize[0])
const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;

const dispatch = {
x: tR,
y: v.dims[1],
z: v.dims[0],
};

const headOffset = v.dims[2] * v.dims[3];
const batchOffset = v.dims[1] * headOffset;
const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: batchOffset },
{ type: DataType.uint32, data: headOffset },
{ type: DataType.uint32, data: v.dims[1] },
{ type: DataType.uint32, data: v.dims[3] },
{ type: DataType.uint32, data: d },
{ type: DataType.float, data: alpha },
];

const outputDims = [v.dims[0], v.dims[2], v.dims[1] * v.dims[3]];
const outputs = [{ dims: outputDims, dataType: q.dataType, gpuDataType: GpuDataType.default }];
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type'];

const getShaderSource = (shaderHelper: ShaderHelper) => {
const qInput = inputVariable('Q', q.dataType, q.dims, components);
const kInput = inputVariable('K', k.dataType, k.dims, components);
const vInput = inputVariable('V', k.dataType, k.dims, components);
const inputVars = [qInput, kInput, vInput];

const output = outputVariable('output', q.dataType, outputDims);
const outputVars = [output];
const type = tensorTypeToWsglValueType(v.dataType);

const uniforms: UniformsArrayType = [
{ name: 'batchOffset', type: 'u32' },
{ name: 'headOffset', type: 'u32' },
{ name: 'headNum', type: 'u32' },
{ name: 'headSize', type: 'u32' },
{ name: 'd', type: 'u32' },
{ name: 'alpha', type: 'f32' as UniformDataElementType },
];

return `
var<workgroup> Q_i : array<array<${qInput.type.storage}, ${qInner}>, ${bR}>;
var<workgroup> KV_j : array<array<${qInput.type.storage}, ${workgroupSize[0]}>, ${bC}>;
var<workgroup> S_i_j : array<array<${output.type.storage}, ${bC}>, ${bR}>;
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, ...outputVars)}
${shaderHelper.mainStart(workgroupSize)}
let offset = (workgroup_id.z * uniforms.batchOffset + workgroup_id.y * uniforms.headOffset) / ${components} + workgroup_id.x * ${bR} * uniforms.d;
var O_i : array<${type}, ${numTiles * colsPerThread}>;
var l_i_j = ${type}(0);
var m_i_j = ${type === 'f32' ? 'f32(-3.402823e+38f)' : 'f16(-65504)'};
for (var tile = 0; tile < ${numTiles}; tile++) {
Q_i[local_id.y][u32(${workgroupSize[0]} * tile) + local_id.x] = Q[offset + local_id.y * uniforms.d + u32(${workgroupSize[0]} * tile) + local_id.x];
}
for (var j = 0; j < ${tC}; j++) {
var acc : array<${type}, ${colsPerThread}>;
let kvOffset = (workgroup_id.z * uniforms.batchOffset + workgroup_id.y * uniforms.headOffset) / ${components} + u32(j * ${bC}) * uniforms.d;
for (var tile = 0; tile < ${numTiles}; tile++) {
KV_j[local_id.y][local_id.x] = K[kvOffset + local_id.y * uniforms.d + local_id.x + u32(tile * 8)];
workgroupBarrier();
for (var col = 0; col < ${colsPerThread}; col++) {
for (var k = 0; k < ${workgroupSize[0]}; k++) {
acc[col] += dot(Q_i[local_id.y][k + tile * 8], KV_j[local_id.x + u32(col * 8)][k]);
}
}
workgroupBarrier();
}
for (var col = 0; col < ${colsPerThread}; col++) {
S_i_j[local_id.y][u32(col * 8) + local_id.x] = acc[col] * ${type}(uniforms.alpha);
}
workgroupBarrier();
let m_i_j_1 = m_i_j;
for (var m = 0; m < ${bC}; m++) {
m_i_j = max(m_i_j, S_i_j[local_id.y][m]);
}
let exp_j_j_1 = exp(m_i_j_1 - m_i_j);
l_i_j *= exp_j_j_1;
for (var o = 0; o < ${colsPerThread * numTiles}; o++) {
O_i[o] *= exp_j_j_1;
}
for (var tile = 0; tile < ${numTiles}; tile++) {
KV_j[local_id.y][local_id.x] = V[kvOffset + local_id.y * uniforms.d + local_id.x + u32(tile * 8)];
workgroupBarrier();
for (var d = 0; d < ${bC}; d++) {
let p_i_j = exp(S_i_j[local_id.y][d] - m_i_j);
if (tile == 0) {
l_i_j += p_i_j;
}
for (var col = 0; col < ${colsPerThread}; col++) {
let v_i_j = KV_j[d][(u32(8 * col) + local_id.x) / 4][(u32(8 * col) + local_id.x) % 4];
O_i[col * ${numTiles} + tile] += p_i_j * v_i_j;
}
}
workgroupBarrier();
}
}
let outputOffset = workgroup_id.z * uniforms.batchOffset + (workgroup_id.x * 32 + local_id.y) * uniforms.headNum
* uniforms.headSize + workgroup_id.y * uniforms.headSize + local_id.x;
for (var tile = 0; tile < ${numTiles}; tile++) {
for (var col = 0; col < ${colsPerThread}; col++) {
let outputIndx = outputOffset + u32(tile * ${bC}) + u32(col * 8);
output[outputIndx] = O_i[col * ${numTiles} + tile] / l_i_j;
}
}
}`;
};
return {
name: 'FlashAttentionV2',
shaderCache: { hint: `${numTiles}`, inputDependencies },
getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
getShaderSource,
};
};

const applyFlashAttentionV2 = (
context: ComputeContext,
q: TensorView,
k: TensorView,
v: TensorView,
parameters: AttentionParameters,
attributes: AttentionAttrs,
) => {
const inputs = [q, k, v];
context.compute(createFlashAttentionV2ProgramInfo(q, k, v, parameters, attributes), { inputs });
};

export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => {
const params = validateAttentionInputs(context.inputs, attributes);

const [q, k, v] = prepare(context, params);

return applyAttention(
context,
q,
k,
v,
context.inputs[4],
undefined,
undefined,
undefined,
context.inputs[5],
params,
);
if (
params.sequenceLength >= 1024 &&
params.sequenceLength % 32 === 0 &&
params.headSize <= 128 &&
params.headSize % 32 === 0 &&
context.inputs[4] === undefined &&
context.inputs[5] === undefined
) {
return applyFlashAttentionV2(context, q, k, v, params, attributes);
} else {
return applyAttention(
context,
q,
k,
v,
context.inputs[4],
undefined,
undefined,
undefined,
context.inputs[5],
params,
);
}
};