[WebGPU] Support models with opset 24 ops and KV-shared decoder layers (Gemma 4) #28501
Merged
Pull request overview
Adds WebGPU EP support for Gemma 4 E2B by registering opset 24 versions of Cast and Shape, and by extending the GroupQueryAttention kernel to handle the model's variable head_dim (e.g. 256) and KV-shared layers where kv_sequence_length==0 and present_key/present_value are optional.
Changes:
- Cast and Shape get explicit opset 23–23 versioned registrations plus new opset 24 registrations (kernels themselves unchanged).
- WebGPU GroupQueryAttention now accepts a missing `present_key`/`present_value` and a zero-length new KV input, routing the call through flash attention using `past_key`/`past_value` as the KV context; rotary on Q is supported via a dummy K buffer.
- `WebgpuAttentionParameters(GroupQueryAttentionParameters)` now initializes `kv_sequence_length_` from `parameters.kv_sequence_length` (previously the Q sequence length), and `ApplyFlashAttention`'s internal present K/V buffer shape uses `kv_num_heads_` instead of `num_heads_`.
- New WebGPU tests cross-check shared-KV decode/prompt/GQA-ratio/rotary against CPU.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Splits Shape and Cast opset 23 registrations into 23–23 versioned and adds opset 24 registrations. |
| onnxruntime/core/providers/webgpu/tensor/shape_op.cc | Adds Shape kernel definition for opset 23–23 and bumps the open-ended registration to opset 24. |
| onnxruntime/core/providers/webgpu/tensor/cast.cc | Adds explicit CreateCastKernelInfo<23,23> and CreateCastKernelInfo<24> template instantiations. |
| onnxruntime/contrib_ops/webgpu/bert/attention_common.h | Fixes GQA kv_sequence_length_ to read from parameters.kv_sequence_length rather than Q's sequence length. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Uses kv_num_heads_ for internal present K/V shape; adds a kv_empty branch that aliases past_key/past_value as present_key/present_value via const_cast. |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Drops the hard requirement on present K/V outputs, adds a kv_empty fast path that skips K/V processing (with dummy K for rotary), and errors out if the non-flash path is reached with shared KV. |
| onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | Adds a use_webgpu parameter to two shared-KV test helpers and four new WebGPU shared-KV tests cross-checked against CPU. |
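The effect of splitting a registration into a closed 23–23 range plus an open-ended opset 24 entry can be pictured with a small lookup model. This is only a sketch under stated assumptions: `KernelEntry`, `FindKernel`, `MakeTable`, and the tag strings are hypothetical names for illustration, not ONNX Runtime's registry API, and the real matching rules may be more involved.

```cpp
#include <cassert>
#include <limits>
#include <optional>
#include <string>
#include <vector>

// Hypothetical marker for a registration with no upper opset bound.
constexpr int kOpenEnded = std::numeric_limits<int>::max();

// Hypothetical model of a versioned kernel registration: the kernel serves
// every opset in the closed range [since, end].
struct KernelEntry {
  std::string op;
  int since;
  int end;
  std::string tag;  // illustrative label only
};

// Returns the tag of the first entry whose range covers `opset`, if any.
std::optional<std::string> FindKernel(const std::vector<KernelEntry>& table,
                                      const std::string& op, int opset) {
  for (const auto& e : table) {
    assert(e.since <= e.end);  // each range must be well-formed
    if (e.op == op && e.since <= opset && opset <= e.end) return e.tag;
  }
  return std::nullopt;
}

// Table shaped like the change summarized above: opset 23 is closed to
// 23-23 and a new open-ended entry starts at opset 24.
std::vector<KernelEntry> MakeTable() {
  return {
      {"Cast", 23, 23, "Cast-23"},
      {"Cast", 24, kOpenEnded, "Cast-24+"},
      {"Shape", 23, 23, "Shape-23"},
      {"Shape", 24, kOpenEnded, "Shape-24+"},
  };
}
```

In this model, a lookup for `Cast` at opset 24 resolves to the new entry, while an opset outside every listed range finds nothing, which is the situation that would push a node to another provider.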
qjia7 reviewed May 18, 2026
baijumeswani pushed a commit to microsoft/onnxruntime-genai that referenced this pull request May 18, 2026
…ation (#2163)

## Summary
- Fix heap corruption crash during WebGPU inference by using the correct allocator (`p_device_inputs_`) for tensors that serve as inputs to the embedding/decoder session
- Previously, `embeddings.cpp` and `multi_modal_features.cpp` allocated these tensors on `p_device_` (WebGPU GPU memory), but the embedding session runs on CPU for WebGPU (without graph capture), causing the CPU session to write into GPU memory.
- This is a companion change to microsoft/onnxruntime#28501.

## Changes
**`src/models/embeddings.cpp`** (lines 24, 52):
- Changed `model_.p_device_->GetAllocator()` → `model_.p_device_inputs_->GetAllocator()` for embedding input tensor allocation and reallocation during sequence length updates

**`src/models/multi_modal_features.cpp`** (lines 66, 85):
- Changed `model_.p_device_->GetAllocator()` → `model_.p_device_inputs_->GetAllocator()` for empty feature tensors in `Update()` and `AllocateEmptyFeatures()`; these are inputs to the embedding session
- Vision/speech model output allocations (lines 42, 98) correctly remain on `p_device_`

## Impact
- **WebGPU**: Fixes crash (heap corruption 0xc0000374) during inference
- **CUDA/DML/RyzenAI**: No-op, since `p_device_inputs_ == p_device_` on these providers
- **CPU**: No-op, since `p_device_inputs_ == p_device_` on CPU

## Test plan
- [x] End-to-end WebGPU inference with Gemma 4 E2B (109 tokens generated correctly)
- [x] CPU inference still works after changes
- [x] Output matches between CPU and WebGPU execution

Co-authored-by: Claude Opus 4 <noreply@anthropic.com>
qjia7 reviewed May 19, 2026
28f533d to 511301d
kunal-vaishnavi previously approved these changes May 26, 2026
Contributor
Author
Hi @guschmue, @hariharans29, can you help review this PR? Thanks.
guschmue previously approved these changes May 27, 2026
Models exported at opset 24 (e.g. Gemma 4) require these registrations to avoid falling back to CPU for basic tensor operations. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
- Size kDummy/kRotary with sequence_length (not 1) to match the kernel's iteration domain, preventing OOB writes during prefill (q_seq > 1)
- Add ORT_ENFORCE in flash_attention kv_empty path to guard against future refactors that might write through the const_cast'd pointers
- Add new test SharedKV_EmptyKV_WithPast_Rotary_Prompt_WebGPU exercising the rotary + kv_empty path with q_seq_len=6
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The internal present KV buffer shape must use num_heads_ for MHA (where kv_num_heads_ is 0) and kv_num_heads_ only for GQA. Using kv_num_heads_ unconditionally caused zero-sized buffers for MHA CrossAttention tests. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
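The head-count rule in the commit message above fits in a few lines. The following is a minimal sketch, not the code from `flash_attention.cc`: the function names are hypothetical, and the element-count layout (batch × heads × total sequence length × head size) is an assumption made for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: pick the head count for the internal present K/V
// buffer. MHA leaves kv_num_heads at 0, so it falls back to num_heads;
// GQA uses kv_num_heads.
uint32_t PresentKvHeads(uint32_t num_heads, uint32_t kv_num_heads) {
  assert(num_heads > 0);
  return kv_num_heads > 0 ? kv_num_heads : num_heads;
}

// Hypothetical helper: element count of one present K (or V) buffer,
// assuming a batch x heads x total_sequence_length x head_size layout.
uint64_t PresentKvElements(uint32_t batch, uint32_t num_heads,
                           uint32_t kv_num_heads,
                           uint32_t total_sequence_length,
                           uint32_t head_size) {
  const uint64_t heads = PresentKvHeads(num_heads, kv_num_heads);
  return static_cast<uint64_t>(batch) * heads * total_sequence_length *
         head_size;
}
```

With `kv_num_heads == 0`, using `kv_num_heads` unconditionally would make the element count zero, which matches the zero-sized buffers the commit message describes for MHA.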
Remove duplicate tests (GQARatio8, Rotary_Prompt) that don't exercise distinct code paths. Rename remaining tests for clarity. Simplify kv_empty error message. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The rebase took main's version of RunGQASharedKV/RunGQASharedKVWithRotary which only had use_cuda. Add use_webgpu parameter to fix compile errors. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Dispatch over the Q output size instead of the full packed QKV input, reducing wasted threads by ~(hidden + 2*kv_hidden) / hidden ratio. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
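A worked instance of the ratio in the commit message above, as a sketch: `ComputeDispatchSizes` is a hypothetical helper, and the sizes used in the note below (8 query heads, 1 KV head, head size 256) are illustrative values rather than a specific model's configuration.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not ORT code: element counts for the two dispatch
// choices described in the commit message.
struct DispatchSizes {
  uint64_t packed_qkv;  // dispatching over the full packed QKV input
  uint64_t q_only;      // dispatching over the Q output only
};

DispatchSizes ComputeDispatchSizes(uint64_t batch, uint64_t seq_len,
                                   uint64_t num_heads, uint64_t kv_num_heads,
                                   uint64_t head_size) {
  assert(num_heads > 0 && head_size > 0);
  const uint64_t hidden = num_heads * head_size;        // Q width per token
  const uint64_t kv_hidden = kv_num_heads * head_size;  // K (or V) width
  const uint64_t tokens = batch * seq_len;
  return {tokens * (hidden + 2 * kv_hidden), tokens * hidden};
}
```

For one token with the illustrative sizes, `hidden` is 2048 and `kv_hidden` is 256, so the packed dispatch covers 2560 elements against 2048 for Q only, a ratio of (2048 + 2·256) / 2048 = 1.25.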
Packed QKV with kv_empty is not a valid combination for any existing model. Replace with ORT_ENFORCE to fail fast if encountered. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
When total_sequence_length exceeds local_window_size, the sliding window check was blocking flash attention for kv_empty (shared KV) layers. This is incorrect because sliding window is irrelevant for these layers — they have no local KV cache and reuse another layer's already-computed cache. Add regression test for this case. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
…fsetProgram Introduce RotaryEmbeddingWithOffsetProgram as a separate class that computes position from a uniform offset (position_offset + sequence_index) instead of requiring a position_ids tensor input. This avoids the need for RangeProgram dispatch and keeps the original RotaryEmbeddingProgram untouched. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
…tProgram - WebGPU_SharedKV_Rotary_Prefill: q_seq_len=4, validates position_offset + bsnh[1] arithmetic for multiple sequence positions - WebGPU_SharedKV_Rotary_MultiBatch: batch_size=2, validates batch stride calculations Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Remove RotaryEmbeddingWithOffsetProgram and instead pass a scalar position_ids tensor ([1,1] with value=past_sequence_length) to the existing RotaryEmbeddingProgram. The shader's broadcast logic already computes position_id = raw_pos + bsnh[1] when position_ids is scalar, eliminating the need for a separate class. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Instead of using RangeProgram to create a position_ids tensor, add a use_position_offset mode to RotaryEmbeddingProgram that computes position_id = position_offset + bsnh[1] directly in the shader.
- Add use_position_offset_ flag and position_offset uniform to RotaryEmbeddingProgram (existing callers pass 0u, unused)
- Merge shared rotation math in GenerateShaderCode to reduce duplication
- Remove RangeProgram dependency from GQA kv_empty path
- CacheHint differentiates the two compiled shader variants
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The contrib RotaryEmbedding::ComputeInternal also uses RotaryEmbeddingProgram but was not updated to pass the 5th uniform. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Simpler to read with one top-level branch per mode rather than interleaved conditionals with shared code blocks. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
4fb508c to 08240cb
Copilot stopped work on behalf of feich-ms due to an error June 2, 2026 05:42
The legacy caller passes inputs as [input, position_ids, cos_cache, sin_cache], but position_ids was being declared after cos_cache/sin_cache in the else branch, causing input slot mismatch and pipeline compilation failure. Hoist the conditional position_ids declaration to occur before cos_cache/sin_cache so the shader's input declaration order matches the caller's AddInputs order. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…X op
- Drop position_offset uniform; derive position_id in shader from seqlens[batch_idx] (GQA path), matching the FusedQKRotaryEmbedding approach. Saves a per-call uniform and removes redundant host-side offset bookkeeping.
- Promote RunRotaryEmbedding to a shared helper used by both group_query_attention.cc (kv_empty Q-only rotary) and rotary_embedding.cc (RotaryEmbedding::ComputeInternal). The use_seqlens_for_position flag selects the per-batch seqlens path vs. the legacy position_ids path.
- Rename the second program input to position_ids_or_seqlens for clearer semantics.
- Update the ONNX-domain RotaryEmbedding kernel to drop the now-removed {0u} position_offset uniform from both AddUniformVariables call sites.
qjia7 approved these changes Jun 2, 2026
Description
This PR adds WebGPU EP changes to support models with opset 24 ops and KV-shared decoder layers (like Gemma 4):

Opset 24 kernel registrations:
- `webgpu_execution_provider.cc` registration table updated accordingly

GQA (GroupQueryAttention) kv_empty path for KV-shared layers:
- Handles `kv_sequence_length==0`, where layers reuse another layer's KV cache via `past_key`/`past_value`
- Q-only rotary runs `RotaryEmbeddingProgram` through a shared `RunRotaryEmbedding` helper. When invoked with `use_seqlens_for_position=true`, the shader derives `position_id` per batch directly from the `seqlens` tensor (`past_seqlen = seqlens[batch] + 1 - global_shape[1]`, then `position_id = past_seqlen + sequence_index`); no host-side offset uniform and no scratch position_ids tensor are needed
- `RunRotaryEmbedding` is now a single shared helper used by both the contrib GQA kernel (kv_empty Q-only rotary) and the standalone RotaryEmbedding kernel (`ComputeInternal`); the boolean `use_seqlens_for_position` toggles between the seqlens-derived path and the legacy `position_ids` path
- No `present_key`/`present_value` GPU allocation for the kv_empty path (aliased to past instead)
- `present_key`/`present_value` outputs are optional (nullptr when the model doesn't request them)
- `kv_sequence_length_` initialization in `WebgpuAttentionParameters` is fixed (was incorrectly using Q sequence length)
- `present_kv_heads` uses `kv_num_heads_` instead of `num_heads_` for GQA internal buffer shapes

Tests:
- `WebGPU_SharedKV_Decode`: kv_empty decode path (S=1)
- `WebGPU_SharedKV_Prefill`: kv_empty prefill path (S>1, tiled attention)
- `WebGPU_SharedKV_Rotary`: kv_empty with Q-only rotary embedding
- `WebGPU_SharedKV_Rotary_Prefill`: Q-only rotary with multi-token prefill (q_seq=4)
- `WebGPU_SharedKV_Rotary_MultiBatch`: Q-only rotary with batch_size=2
- `WebGPU_SharedKV_SlidingWindow`: kv_empty with sliding window attention

All WebGPU tests cross-check against CPU for numerical correctness.
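The position derivation quoted in the description can be checked on the host with a short sketch. This is a CPU-side illustration of the arithmetic only, not the WGSL shader: `PositionIdsFromSeqlens` is a hypothetical name, and it assumes `global_shape[1]` is the Q sequence length and that `seqlens[batch]` holds the total sequence length minus one.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side mirror of the seqlens-derived position formula:
//   past_seqlen = seqlens[batch] + 1 - q_sequence_length
//   position_id = past_seqlen + sequence_index
// Returns position ids flattened as [batch, q_sequence_length].
std::vector<int32_t> PositionIdsFromSeqlens(const std::vector<int32_t>& seqlens,
                                            int32_t q_sequence_length) {
  assert(q_sequence_length > 0);
  std::vector<int32_t> position_ids;
  position_ids.reserve(seqlens.size() *
                       static_cast<std::size_t>(q_sequence_length));
  for (int32_t seqlen : seqlens) {
    const int32_t past_seqlen = seqlen + 1 - q_sequence_length;
    assert(past_seqlen >= 0);  // new tokens cannot outnumber the total
    for (int32_t s = 0; s < q_sequence_length; ++s) {
      position_ids.push_back(past_seqlen + s);
    }
  }
  return position_ids;
}
```

For a two-sequence batch with `seqlens = {9, 4}` and four new query tokens, the first sequence gets positions 6 through 9 and the second 1 through 4; a single-token decode step with `seqlens = {9}` gets position 9.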
Motivation and Context
Models with KV-shared decoder layers (e.g. Gemma 4 E2B: 20 of 35 layers share KV cache) pass `kv_sequence_length=0` and empty K/V inputs to GQA nodes that reuse another layer's cached KV. The WebGPU GQA kernel previously rejected this pattern. These changes enable correct and efficient execution by:
- Applying rotary to Q only through the `RunRotaryEmbedding` helper with `use_seqlens_for_position=true`, which lets the shader derive each token's position from `seqlens[batch_idx]` (K is already rotated in the shared cache)
- Reusing `RotaryEmbeddingProgram`: no new shader classes needed; the only addition is a `use_seqlens_for_position_` branch inside `GenerateShaderCode` that selects between reading `position_ids` and computing `position_id` from `seqlens`

Opset 24 Cast/Shape registrations prevent fallback to CPU for models exported at opset 24.