Skip to content

Commit dde251f

Browse files
committed
Fix WebGPU MHA to ignore past key/value when no present outputs requested
1 parent 2bf09e9 commit dde251f

1 file changed

Lines changed: 9 additions & 6 deletions

File tree

js/web/lib/wasm/jsep/webgpu/ops/attention.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -809,14 +809,17 @@ export const applyAttention = (
809809
) => {
810810
// Assumption is that presentKey/presentValue exists only if pastKey/pastValue exists.
811811
const outputCount = Math.min(context.outputCount, 1 + (pastKey ? 1 : 0) + (pastValue ? 1 : 0));
812+
// When there are no present key/value outputs (outputCount <= 1), ignore past to match CPU EP semantics.
813+
const effectivePastKey = outputCount > 1 ? pastKey : undefined;
814+
const effectivePastValue = outputCount > 1 ? pastValue : undefined;
812815
const pastSequenceLength = outputCount > 1 ? parameters.pastSequenceLength : 0;
813816
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
814817
const attentionBias =
815818
attentionBiasInput && ShapeUtil.size(attentionBiasInput.dims) > 0 ? attentionBiasInput : undefined;
816819

817820
const inputsK = [q, k];
818-
if (outputCount > 1 && pastKey && ShapeUtil.size(pastKey.dims) > 0) {
819-
inputsK.push(pastKey);
821+
if (effectivePastKey && ShapeUtil.size(effectivePastKey.dims) > 0) {
822+
inputsK.push(effectivePastKey);
820823
}
821824
if (attentionBias) {
822825
inputsK.push(attentionBias);
@@ -833,7 +836,7 @@ export const applyAttention = (
833836
outputCount,
834837
q,
835838
k,
836-
pastKey,
839+
effectivePastKey,
837840
attentionBias,
838841
parameters,
839842
pastSequenceLength,
@@ -860,8 +863,8 @@ export const applyAttention = (
860863

861864
// Run AttentionScore
862865
const inputsV = [probs, v];
863-
if (outputCount > 1 && pastValue && ShapeUtil.size(pastValue.dims) > 0) {
864-
inputsV.push(pastValue);
866+
if (effectivePastValue && ShapeUtil.size(effectivePastValue.dims) > 0) {
867+
inputsV.push(effectivePastValue);
865868
}
866869
if (seqLens) {
867870
inputsV.push(seqLens);
@@ -874,7 +877,7 @@ export const applyAttention = (
874877
outputCount,
875878
probs,
876879
v,
877-
pastValue,
880+
effectivePastValue,
878881
parameters,
879882
pastSequenceLength,
880883
seqLens,

0 commit comments

Comments (0)