
[WebGPU] Support models with opset 24 ops and KV-shared decoder layers (Gemma 4) #28501

Merged: feich-ms merged 30 commits into main from user/feich/webgpu-gemma4-opset24 on Jun 2, 2026

feich-ms (Contributor) commented May 14, 2026

Description

This PR adds WebGPU EP changes to support models with opset 24 ops and KV-shared decoder layers (like Gemma 4):

Opset 24 kernel registrations:

  • Cast op: version the opset 23 registration to 23-23 and add opset 24 registration
  • Shape op: version the opset 23 registration to 23-23 and add opset 24 registration
  • Updated webgpu_execution_provider.cc registration table accordingly
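
Why a 23-23 range plus a new opset 24 entry is needed can be sketched with a toy registration table. This is a simplified model, not ONNX Runtime's registry: `KernelEntry`, `Matches`, `FindKernel`, `TableBeforePr`, and `TableAfterPr` are hypothetical names, and the matching rule (an open-ended entry serves only the exact operator version it was registered for) is an assumption chosen to be consistent with the CPU fallback described in this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified model of one kernel registration.
struct KernelEntry {
  std::string op;
  int start_version;  // first operator version served
  int end_version;    // last operator version served, or kOpenEnded
};

constexpr int kOpenEnded = std::numeric_limits<int>::max();

// Assumed rule: an exact start match always wins; otherwise only a
// closed range [start, end] may cover a later operator version.
bool Matches(const KernelEntry& e, const std::string& op, int op_since_version) {
  if (e.op != op) return false;
  if (e.start_version == op_since_version) return true;
  return e.end_version != kOpenEnded &&
         e.start_version < op_since_version &&
         op_since_version <= e.end_version;
}

// Returns the index of the first matching entry, if any.
std::optional<std::size_t> FindKernel(const std::vector<KernelEntry>& table,
                                      const std::string& op, int op_since_version) {
  assert(op_since_version > 0);
  for (std::size_t i = 0; i < table.size(); ++i) {
    if (Matches(table[i], op, op_since_version)) return i;
  }
  return std::nullopt;
}

// Before: a single open-ended entry at opset 23 per op.
std::vector<KernelEntry> TableBeforePr() {
  return {KernelEntry{"Cast", 23, kOpenEnded}, KernelEntry{"Shape", 23, kOpenEnded}};
}

// After: opset 23 closed to 23-23, new open-ended entry at opset 24.
std::vector<KernelEntry> TableAfterPr() {
  return {KernelEntry{"Cast", 23, 23}, KernelEntry{"Cast", 24, kOpenEnded},
          KernelEntry{"Shape", 23, 23}, KernelEntry{"Shape", 24, kOpenEnded}};
}
```

Under this model, a lookup for Cast at version 24 finds nothing in the "before" table and finds the new entry in the "after" table, while version 23 models keep resolving to the 23-23 entry.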

GQA (GroupQueryAttention) kv_empty path for KV-shared layers:

  • Support kv_sequence_length==0 where layers reuse another layer's KV cache via past_key/past_value
  • Apply Q-only rotary embedding by reusing RotaryEmbeddingProgram through a shared RunRotaryEmbedding helper. When invoked with use_seqlens_for_position=true, the shader derives position_id per batch directly from the seqlens tensor: past_seqlen = seqlens[batch] + 1 - global_shape[1], then position_id = past_seqlen + sequence_index. Neither a host-side offset uniform nor a scratch position_ids tensor is needed
  • RunRotaryEmbedding is now a single shared helper used by both the contrib GQA kernel (kv_empty Q-only rotary) and the standalone RotaryEmbedding kernel (ComputeInternal); the boolean use_seqlens_for_position toggles between the seqlens-derived path and the legacy position_ids path
  • Skip wasteful present_key/present_value GPU allocation for kv_empty path (aliased to past instead)
  • Make present_key/present_value outputs optional (nullptr when model doesn't request them)
  • Fix kv_sequence_length_ initialization in WebgpuAttentionParameters (was incorrectly using Q sequence length)
  • Fix present_kv_heads to use kv_num_heads_ instead of num_heads_ for GQA internal buffer shapes
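
The seqlens-derived position arithmetic from the list above can be mirrored on the host as a minimal sketch. It assumes that `seqlens[batch]` holds the total sequence length minus one and that `global_shape[1]` is the Q sequence length; `PositionId` is a hypothetical helper name, not a function from the PR.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Host-side model of the shader arithmetic:
//   past_seqlen = seqlens[batch] + 1 - q_sequence_length
//   position_id = past_seqlen + sequence_index
// Assumption: seqlens[batch] == total_sequence_length - 1.
std::uint32_t PositionId(const std::vector<std::int32_t>& seqlens,
                         std::uint32_t batch,
                         std::uint32_t q_sequence_length,
                         std::uint32_t sequence_index) {
  assert(batch < seqlens.size());
  assert(sequence_index < q_sequence_length);
  const std::int32_t past_seqlen =
      seqlens[batch] + 1 - static_cast<std::int32_t>(q_sequence_length);
  assert(past_seqlen >= 0);  // the cache must already cover the past tokens
  return static_cast<std::uint32_t>(past_seqlen) + sequence_index;
}
```

For a decode step (one query token) the position equals `seqlens[batch]`; for a four-token prefill on top of six cached tokens (`seqlens[batch] == 9`) the positions are 6 through 9.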

Tests:

  • WebGPU_SharedKV_Decode — kv_empty decode path (S=1)
  • WebGPU_SharedKV_Prefill — kv_empty prefill path (S>1, tiled attention)
  • WebGPU_SharedKV_Rotary — kv_empty with Q-only rotary embedding
  • WebGPU_SharedKV_Rotary_Prefill — Q-only rotary with multi-token prefill (q_seq=4)
  • WebGPU_SharedKV_Rotary_MultiBatch — Q-only rotary with batch_size=2
  • WebGPU_SharedKV_SlidingWindow — kv_empty with sliding window attention

All WebGPU tests cross-check against CPU for numerical correctness.

Motivation and Context

Models with KV-shared decoder layers (e.g. Gemma 4 E2B: 20 of 35 layers share KV cache) pass kv_sequence_length=0 and empty K/V inputs to GQA nodes that reuse another layer's cached KV. The WebGPU GQA kernel previously rejected this pattern. These changes enable correct and efficient execution by:

  1. Applying rotary to Q only via the shared RunRotaryEmbedding helper with use_seqlens_for_position=true, which lets the shader derive each token's position from seqlens[batch_idx] (K is already rotated in the shared cache)
  2. Aliasing past as present (avoiding unnecessary allocation and copy)
  3. Reusing the existing RotaryEmbeddingProgram — no new shader classes needed; the only addition is a use_seqlens_for_position_ branch inside GenerateShaderCode that selects between reading position_ids and computing position_id from seqlens
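
Item 1 (rotating Q only) can be illustrated with a CPU reference for a single head vector. This is a sketch under stated assumptions: it uses the non-interleaved ("rotate half") layout, `ApplyRotaryToQ` is a hypothetical name, and `cos_row`/`sin_row` stand for the cache rows selected by the token's position_id. It is not the WebGPU shader.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Rotates one Q head vector in place. K is deliberately not touched:
// in a KV-shared layer the cached K was already rotated by the layer
// that produced it.
void ApplyRotaryToQ(std::vector<float>& q,
                    const std::vector<float>& cos_row,
                    const std::vector<float>& sin_row) {
  assert(q.size() % 2 == 0);
  const std::size_t half = q.size() / 2;
  assert(cos_row.size() == half && sin_row.size() == half);
  for (std::size_t i = 0; i < half; ++i) {
    const float a = q[i];         // first half element
    const float b = q[i + half];  // paired second half element
    q[i] = a * cos_row[i] - b * sin_row[i];
    q[i + half] = b * cos_row[i] + a * sin_row[i];
  }
}
```

With `cos = 1, sin = 0` the vector is unchanged, and with `cos = 0, sin = 1` each pair `(a, b)` becomes `(-b, a)`, which is a quarter-turn rotation.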

Opset 24 Cast/Shape registrations prevent fallback to CPU for models exported at opset 24.

@feich-ms feich-ms requested a review from Copilot May 14, 2026 03:56

Copilot AI left a comment

Pull request overview

Adds WebGPU EP support for Gemma 4 E2B by registering opset 24 versions of Cast and Shape, and by extending the GroupQueryAttention kernel to handle the model's variable head_dim (e.g. 256) and KV-shared layers where kv_sequence_length==0 and present_key/present_value are optional.

Changes:

  • Cast and Shape get explicit opset 23–23 versioned registrations plus new opset 24 registrations (kernels themselves unchanged).
  • WebGPU GroupQueryAttention now accepts a missing present_key/present_value and a zero-length new KV input, routing the call through flash attention using past_key/past_value as the KV context; rotary on Q is supported via a dummy K buffer.
  • WebgpuAttentionParameters(GroupQueryAttentionParameters) now initializes kv_sequence_length_ from parameters.kv_sequence_length (previously the Q sequence length), and ApplyFlashAttention's internal present K/V buffer shape uses kv_num_heads_ instead of num_heads_.
  • New WebGPU tests cross-check shared-KV decode/prompt/GQA-ratio/rotary against CPU.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Splits Shape and Cast opset 23 registrations into 23–23 versioned and adds opset 24 registrations. |
| onnxruntime/core/providers/webgpu/tensor/shape_op.cc | Adds Shape kernel definition for opset 23–23 and bumps the open-ended registration to opset 24. |
| onnxruntime/core/providers/webgpu/tensor/cast.cc | Adds explicit CreateCastKernelInfo<23,23> and CreateCastKernelInfo<24> template instantiations. |
| onnxruntime/contrib_ops/webgpu/bert/attention_common.h | Fixes GQA kv_sequence_length_ to read from parameters.kv_sequence_length rather than Q's sequence length. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Uses kv_num_heads_ for internal present K/V shape; adds a kv_empty branch that aliases past_key/past_value as present_key/present_value via const_cast. |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Drops the hard requirement on present K/V outputs, adds a kv_empty fast path that skips K/V processing (with dummy K for rotary), and errors out if the non-flash path is reached with shared KV. |
| onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | Adds a use_webgpu parameter to two shared-KV test helpers and four new WebGPU shared-KV tests cross-checked against CPU. |

baijumeswani pushed a commit to microsoft/onnxruntime-genai that referenced this pull request May 18, 2026
…ation (#2163)

## Summary

- Fix heap corruption crash during WebGPU inference by using the correct
allocator (`p_device_inputs_`) for tensors that serve as inputs to the
embedding/decoder session
- Previously, `embeddings.cpp` and `multi_modal_features.cpp` allocated
these tensors on `p_device_` (WebGPU GPU memory), but the embedding
session runs on CPU for WebGPU (without graph capture), causing the CPU
session to write into GPU memory.
- This is a companion change to
microsoft/onnxruntime#28501.

## Changes

**`src/models/embeddings.cpp`** (lines 24, 52):
- Changed `model_.p_device_->GetAllocator()` →
`model_.p_device_inputs_->GetAllocator()` for embedding input tensor
allocation and reallocation during sequence length updates

**`src/models/multi_modal_features.cpp`** (lines 66, 85):
- Changed `model_.p_device_->GetAllocator()` →
`model_.p_device_inputs_->GetAllocator()` for empty feature tensors in
`Update()` and `AllocateEmptyFeatures()` — these are inputs to the
embedding session
- Vision/speech model output allocations (lines 42, 98) correctly remain
on `p_device_`

## Impact

- **WebGPU**: Fixes crash (heap corruption 0xc0000374) during inference
- **CUDA/DML/RyzenAI**: No-op — `p_device_inputs_ == p_device_` on these
providers
- **CPU**: No-op — `p_device_inputs_ == p_device_` on CPU

## Test plan

- [x] End-to-end WebGPU inference with Gemma 4 E2B (109 tokens generated
correctly)
- [x] CPU inference still works after changes
- [x] Output matches between CPU and WebGPU execution

Co-authored-by: Claude Opus 4 <noreply@anthropic.com>
@feich-ms feich-ms force-pushed the user/feich/webgpu-gemma4-opset24 branch from 28f533d to 511301d Compare May 25, 2026 02:08

@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.

@feich-ms feich-ms changed the title from "[WebGPU] Add opset 24 support and Gemma 4 GQA enhancements" to "[WebGPU] Add WebGPU EP changes to support models with opset 24 ops and KV-shared decoder layers" May 25, 2026
@feich-ms feich-ms changed the title from "[WebGPU] Add WebGPU EP changes to support models with opset 24 ops and KV-shared decoder layers" to "[WebGPU] Support models with opset 24 ops and KV-shared decoder layers (Gemma 4)" May 25, 2026
@feich-ms feich-ms requested review from guschmue and hariharans29 and removed request for baijumeswani and kunal-vaishnavi May 26, 2026 05:32
@feich-ms (Contributor, Author)

Hi @guschmue, @hariharans29, could you please review this PR? Thanks.

@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.

guschmue previously approved these changes May 27, 2026

@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.


@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.

feich-ms and others added 3 commits June 2, 2026 13:04
Models exported at opset 24 (e.g. Gemma 4) require these registrations
to avoid falling back to CPU for basic tensor operations.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
- Size kDummy/kRotary with sequence_length (not 1) to match the kernel's
  iteration domain, preventing OOB writes during prefill (q_seq > 1)
- Add ORT_ENFORCE in flash_attention kv_empty path to guard against
  future refactors that might write through the const_cast'd pointers
- Add new test SharedKV_EmptyKV_WithPast_Rotary_Prompt_WebGPU exercising
  the rotary + kv_empty path with q_seq_len=6

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The internal present KV buffer shape must use num_heads_ for MHA
(where kv_num_heads_ is 0) and kv_num_heads_ only for GQA. Using
kv_num_heads_ unconditionally caused zero-sized buffers for MHA
CrossAttention tests.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
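
The head-count selection described in the last commit above can be sketched as follows. `AttnParams`, `PresentKvHeads`, and `PresentKvElements` are hypothetical names for illustration, and the buffer layout (batch, heads, total sequence length, head size) is an assumption; only the rule itself (use kv_num_heads for GQA, fall back to num_heads when kv_num_heads is 0) comes from the commit message.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical subset of the attention parameters.
struct AttnParams {
  std::uint32_t batch_size;
  std::uint32_t num_heads;
  std::uint32_t kv_num_heads;  // 0 for plain MHA, > 0 for GQA
  std::uint32_t total_sequence_length;
  std::uint32_t head_size;
};

// GQA uses kv_num_heads; MHA (kv_num_heads == 0) falls back to num_heads.
std::uint32_t PresentKvHeads(const AttnParams& p) {
  assert(p.num_heads > 0);
  return p.kv_num_heads > 0 ? p.kv_num_heads : p.num_heads;
}

// Element count of one internal present K (or V) buffer.
std::uint64_t PresentKvElements(const AttnParams& p) {
  return static_cast<std::uint64_t>(p.batch_size) * PresentKvHeads(p) *
         p.total_sequence_length * p.head_size;
}
```

Using `kv_num_heads` unconditionally would make the MHA case multiply by zero, which matches the zero-sized buffers reported for the CrossAttention tests.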
feich-ms and others added 16 commits June 2, 2026 13:04
Remove duplicate tests (GQARatio8, Rotary_Prompt) that don't exercise
distinct code paths. Rename remaining tests for clarity. Simplify
kv_empty error message.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The rebase took main's version of RunGQASharedKV/RunGQASharedKVWithRotary
which only had use_cuda. Add use_webgpu parameter to fix compile errors.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Dispatch over the Q output size instead of the full packed QKV input,
reducing wasted threads by ~(hidden + 2*kv_hidden) / hidden ratio.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
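
The thread saving claimed in the commit above follows from simple element counts. The sketch below is illustrative only: the function names are hypothetical, and the sizes used in the usage note (hidden = 2048, kv_hidden = 256) are made-up numbers, not taken from any model in this PR.

```cpp
#include <cassert>
#include <cstdint>

// Elements covered when dispatching over the full packed QKV input.
std::uint64_t PackedQkvElements(std::uint64_t batch, std::uint64_t seq,
                                std::uint64_t hidden, std::uint64_t kv_hidden) {
  return batch * seq * (hidden + 2 * kv_hidden);
}

// Elements covered when dispatching over the Q output only.
std::uint64_t QOnlyElements(std::uint64_t batch, std::uint64_t seq,
                            std::uint64_t hidden) {
  return batch * seq * hidden;
}

// Ratio of the two dispatch sizes: (hidden + 2 * kv_hidden) / hidden.
double DispatchRatio(std::uint64_t hidden, std::uint64_t kv_hidden) {
  assert(hidden > 0);
  return static_cast<double>(hidden + 2 * kv_hidden) /
         static_cast<double>(hidden);
}
```

With the made-up sizes above and four tokens, the packed dispatch covers 10240 elements against 8192 for Q only, a ratio of 1.25.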
Packed QKV with kv_empty is not a valid combination for any existing
model. Replace with ORT_ENFORCE to fail fast if encountered.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
When total_sequence_length exceeds local_window_size, the sliding window
check was blocking flash attention for kv_empty (shared KV) layers. This
is incorrect because sliding window is irrelevant for these layers — they
have no local KV cache and reuse another layer's already-computed cache.
Add regression test for this case.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
…fsetProgram

Introduce RotaryEmbeddingWithOffsetProgram as a separate class that computes
position from a uniform offset (position_offset + sequence_index) instead of
requiring a position_ids tensor input. This avoids the need for RangeProgram
dispatch and keeps the original RotaryEmbeddingProgram untouched.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
…tProgram

- WebGPU_SharedKV_Rotary_Prefill: q_seq_len=4, validates position_offset + bsnh[1]
  arithmetic for multiple sequence positions
- WebGPU_SharedKV_Rotary_MultiBatch: batch_size=2, validates batch stride calculations

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Remove RotaryEmbeddingWithOffsetProgram and instead pass a scalar
position_ids tensor ([1,1] with value=past_sequence_length) to the
existing RotaryEmbeddingProgram. The shader's broadcast logic already
computes position_id = raw_pos + bsnh[1] when position_ids is scalar,
eliminating the need for a separate class.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Instead of using RangeProgram to create a position_ids tensor, add a
use_position_offset mode to RotaryEmbeddingProgram that computes
position_id = position_offset + bsnh[1] directly in the shader.

- Add use_position_offset_ flag and position_offset uniform to
  RotaryEmbeddingProgram (existing callers pass 0u, unused)
- Merge shared rotation math in GenerateShaderCode to reduce duplication
- Remove RangeProgram dependency from GQA kv_empty path
- CacheHint differentiates the two compiled shader variants

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The contrib RotaryEmbedding::ComputeInternal also uses
RotaryEmbeddingProgram but was not updated to pass the 5th uniform.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Simpler to read with one top-level branch per mode rather than
interleaved conditionals with shared code blocks.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
@feich-ms feich-ms force-pushed the user/feich/webgpu-gemma4-opset24 branch from 4fb508c to 08240cb Compare June 2, 2026 05:04

@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.

feich-ms and others added 3 commits June 2, 2026 14:08
The legacy caller passes inputs as [input, position_ids, cos_cache,
sin_cache], but position_ids was being declared after cos_cache/sin_cache
in the else branch, causing input slot mismatch and pipeline compilation
failure.

Hoist the conditional position_ids declaration to occur before
cos_cache/sin_cache so the shader's input declaration order matches the
caller's AddInputs order.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…X op

- Drop position_offset uniform; derive position_id in shader from seqlens[batch_idx] (GQA path), matching the FusedQKRotaryEmbedding approach. Saves a per-call uniform and removes redundant host-side offset bookkeeping.

- Promote RunRotaryEmbedding to a shared helper used by both group_query_attention.cc (kv_empty Q-only rotary) and rotary_embedding.cc (RotaryEmbedding::ComputeInternal). The use_seqlens_for_position flag selects the per-batch seqlens path vs. the legacy position_ids path.

- Rename the second program input to position_ids_or_seqlens for clearer semantics.

- Update the ONNX-domain RotaryEmbedding kernel to drop the now-removed {0u} position_offset uniform from both AddUniformVariables call sites.
@feich-ms feich-ms enabled auto-merge (squash) June 2, 2026 09:33
@feich-ms feich-ms merged commit e1f27d1 into main Jun 2, 2026
86 checks passed
@feich-ms feich-ms deleted the user/feich/webgpu-gemma4-opset24 branch June 2, 2026 10:16

Labels: ep:WebGPU (ort-web webgpu provider)


5 participants