[WebGPU] Fix MHA to ignore past key/value when no present outputs requested#28027
Merged
Pull request overview
This PR fixes WebGPU EP MultiHeadAttention behavior to match CPU EP semantics by ensuring past key/value inputs are ignored when the node requests only the primary output (i.e., no present_key / present_value outputs). This prevents past KV from influencing attention results and avoids incorrect shader cache key specialization when present outputs aren’t requested.
Changes:
- Introduce `effectivePastKey`/`effectivePastValue`, which are set to `undefined` when `outputCount <= 1`.
- Route all downstream usage (program creation and input binding) through these effective values, so past tensors are never passed to shader creation/dispatch in the "no present outputs" case.
guschmue approved these changes on Apr 22, 2026
Description
When MultiHeadAttention has only 1 output (no present_key/present_value outputs), past key/value inputs should be completely ignored, matching CPU EP semantics. The WebGPU EP was passing pastKey/pastValue TensorViews to shader creation functions even when outputCount <= 1, which affected shader cache keys and allowed past data to leak into the attention computation.
This caused the test "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue" to fail with output [17,18,19,20] (pastValue data) instead of expected [9,10,11,12] (V data). The failing output matches exactly what happens when past IS used: Q·pastKey=75 dominates Q·K=35, so softmax gives ~100% weight to pastValue.
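The dominance argument above can be checked numerically. A minimal sketch, using only the two logits stated in the description (Q·pastKey = 75, Q·K = 35; the underlying Q/K vectors are not shown in the PR, so they are not reconstructed here):

```typescript
// Softmax over the two attention logits from the failing test:
// Q·pastKey = 75 (past position) and Q·K = 35 (current position).
const logits = [75, 35];
const maxLogit = Math.max(...logits);
const exps = logits.map((l) => Math.exp(l - maxLogit));
const sum = exps.reduce((a, b) => a + b, 0);
const weights = exps.map((e) => e / sum);

// The output mixes pastValue and V by these weights; since
// exp(35 - 75) = exp(-40) is about 4e-18, the result is effectively
// pastValue, matching the observed wrong output [17,18,19,20].
const pastValue = [17, 18, 19, 20];
const v = [9, 10, 11, 12];
const out = pastValue.map((pv, i) => weights[0] * pv + weights[1] * v[i]);
console.log(out); // ≈ [17, 18, 19, 20]
```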
Fix
In `applyAttention()`, introduce `effectivePastKey`/`effectivePastValue`, which are set to `undefined` when `outputCount <= 1`. All downstream usage (shader creation, input arrays) uses these effective values instead of the raw parameters. This ensures:
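The guard can be sketched as follows. This is a simplified illustration, not the actual patch: the real `applyAttention()` in the WebGPU EP takes many more parameters, and `TensorView` and `selectEffectivePastKV` here are stand-ins introduced for the example.

```typescript
interface TensorView { readonly dims: readonly number[] }

// Returns the past key/value to actually use downstream: when only the
// primary output is requested (outputCount <= 1), past KV is dropped so
// it cannot affect the shader cache key or the attention computation.
function selectEffectivePastKV(
  outputCount: number,
  pastKey?: TensorView,
  pastValue?: TensorView,
): [TensorView | undefined, TensorView | undefined] {
  const effectivePastKey = outputCount > 1 ? pastKey : undefined;
  const effectivePastValue = outputCount > 1 ? pastValue : undefined;
  return [effectivePastKey, effectivePastValue];
}

// With a single output, past tensors are ignored even when supplied.
const past: TensorView = { dims: [1, 1, 1, 4] };
console.log(selectEffectivePastKV(1, past, past)); // [undefined, undefined]
console.log(selectEffectivePastKV(3, past, past)[0] === past); // true
```

Routing every later use (program creation and input binding) through the returned pair keeps the "no present outputs" case from specializing the shader on past KV shapes.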