
[WebGPU] Fix MHA to ignore past key/value when no present outputs requested (#28027)

Merged
vraspar merged 1 commit into main from vraspar/fix-webgpu-mha-past-ignored on Apr 27, 2026

Conversation

@vraspar (Contributor) commented Apr 9, 2026

Description

When MultiHeadAttention has only 1 output (no present_key/present_value outputs), past key/value inputs should be completely ignored, matching CPU EP semantics. The WebGPU EP was passing pastKey/pastValue TensorViews to shader creation functions even when outputCount <= 1, which affected shader cache keys and allowed past data to leak into the attention computation.

This caused the test "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue" to fail with output [17,18,19,20] (pastValue data) instead of expected [9,10,11,12] (V data). The failing output matches exactly what happens when past IS used: Q·pastKey=75 dominates Q·K=35, so softmax gives ~100% weight to pastValue.
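The dominance claim can be checked numerically. A quick sketch (assuming the standard 1/√head_size scaling, so scale = 0.5 for head-size 4, and the logits and value rows quoted above):

```typescript
// Reproduce the failing case: with pastKey wrongly concatenated, the two
// attention logits are Q·pastKey = 75 and Q·K = 35.
// Head-size is 4, so the usual scale is 1/sqrt(4) = 0.5 (assumption).
const logits = [75, 35].map((x) => x * 0.5); // [37.5, 17.5]

// Numerically stable softmax over the two logits.
const max = Math.max(...logits);
const exps = logits.map((x) => Math.exp(x - max));
const sum = exps.reduce((a, b) => a + b, 0);
const weights = exps.map((e) => e / sum);

const pastValue = [17, 18, 19, 20]; // value row paired with pastKey
const v = [9, 10, 11, 12];          // value row paired with K

// Attention output is the softmax-weighted sum of the value rows.
const out = pastValue.map((pv, i) => weights[0] * pv + weights[1] * v[i]);

console.log(weights); // weight on pastValue ~1 (logit gap of 20 after scaling)
console.log(out);     // ~[17, 18, 19, 20] — the observed wrong output
```

With a scaled logit gap of 20, the softmax weight on V is about e⁻²⁰ ≈ 2·10⁻⁹, which is why the test sees essentially pure pastValue data.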

Fix

In applyAttention(), introduce effectivePastKey/effectivePastValue that are set to undefined when outputCount <= 1. All downstream usage (shader creation, input arrays) uses these effective values instead of the raw parameters. This ensures:

  • Shader cache keys correctly reflect the "no past" configuration
  • Past tensors are never passed to any shader creation function
  • Behavior matches CPU EP (which ignores past when present outputs are null)
  • GQA is unaffected (always has outputCount >= 3)
  • Vanilla Attention is unaffected (always passes undefined for past)
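The guard described above can be sketched as follows. This is a simplified illustration, not the real applyAttention signature: the TensorView stand-in and the helper name resolveEffectivePast are hypothetical, and the actual function takes many more parameters.

```typescript
// Simplified stand-in for the EP's tensor handle type (illustrative only).
type TensorView = { data: number[] };

interface AttentionInputs {
  q: TensorView;
  k: TensorView;
  v: TensorView;
  pastKey?: TensorView;
  pastValue?: TensorView;
  outputCount: number; // 1 = attention output only; >=3 adds present_key/present_value
}

// The pattern from the fix: derive "effective" past tensors that are
// undefined whenever no present outputs are requested, and route ALL
// downstream usage (shader cache key, shader inputs) through them.
function resolveEffectivePast(inputs: AttentionInputs): {
  effectivePastKey?: TensorView;
  effectivePastValue?: TensorView;
} {
  const usePast = inputs.outputCount > 1;
  return {
    effectivePastKey: usePast ? inputs.pastKey : undefined,
    effectivePastValue: usePast ? inputs.pastValue : undefined,
  };
}

// With outputCount = 1, supplied past tensors are dropped, so the shader
// cache key and inputs reflect the "no past" configuration.
const { effectivePastKey, effectivePastValue } = resolveEffectivePast({
  q: { data: [] }, k: { data: [] }, v: { data: [] },
  pastKey: { data: [13, 14, 15, 16] },
  pastValue: { data: [17, 18, 19, 20] },
  outputCount: 1,
});
console.log(effectivePastKey, effectivePastValue); // undefined undefined
```

Because the effective values replace the raw parameters everywhere downstream, there is no code path left where a past tensor can reach shader creation when present outputs are absent.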

@vraspar vraspar force-pushed the vraspar/fix-webgpu-mha-past-ignored branch from ba3be55 to dde251f on April 9, 2026 20:42
@vraspar vraspar requested a review from Copilot April 9, 2026 20:49
Copilot AI (Contributor) left a comment
Pull request overview

This PR fixes WebGPU EP MultiHeadAttention behavior to match CPU EP semantics by ensuring past key/value inputs are ignored when the node requests only the primary output (i.e., no present_key / present_value outputs). This prevents past KV from influencing attention results and avoids incorrect shader cache key specialization when present outputs aren’t requested.

Changes:

  • Introduce effectivePastKey / effectivePastValue that are set to undefined when outputCount <= 1.
  • Route all downstream usage (program creation + input binding) through the effective past values so past tensors are never passed to shader creation/dispatch in the “no present outputs” case.


@vraspar vraspar requested review from edgchen1 and fs-eire April 9, 2026 21:18
@vraspar vraspar requested review from guschmue and removed request for fs-eire April 22, 2026 20:31
@vraspar vraspar merged commit 3c94f1c into main Apr 27, 2026
120 of 124 checks passed
@vraspar vraspar deleted the vraspar/fix-webgpu-mha-past-ignored branch April 27, 2026 18:40