Open
20 commits
d190ac1
[WebGPU] Fuse Q/K RMSNorm into GroupQueryAttention decode
hariharans29 May 12, 2026
b9585c4
[WebGPU] Extract RunLayerNormProgram helper
hariharans29 May 12, 2026
6863729
[WebGPU] Tighten GroupQueryAttentionPreNormFusion preconditions
hariharans29 May 12, 2026
926437e
[GQA] Reject q_norm_weight / k_norm_weight inputs on CPU/CUDA/JSEP
hariharans29 May 12, 2026
98f4097
Merge remote-tracking branch 'origin' into hari/webgpu_perf_2
hariharans29 May 12, 2026
030d56c
[Test] Use custom checker for SkipsAlreadyFusedNode
hariharans29 May 12, 2026
86d61e5
[Test] Add <cmath> include for std::abs
hariharans29 May 12, 2026
c6965da
[Test] Fix stale comment on BuildQwenQkPostNormPattern
hariharans29 May 12, 2026
b4bd290
[Optimizer] Update GroupQueryAttentionPreNormFusion class doc to refl…
hariharans29 May 12, 2026
c263d6c
[Optimizer] Drop unused k{Q,K}NormWeightInputName constants
hariharans29 May 12, 2026
40986df
[WebGPU] Refresh stale q/k_norm_weight comment in GQA
hariharans29 May 12, 2026
6cbce92
[WebGPU] Clarify qk_norm_epsilon_ comment covers prefill fallback too
hariharans29 May 12, 2026
0fafb16
[Schema] Clarify q_norm_weight is honored by native WebGPU only
hariharans29 May 12, 2026
414684c
[Optimizer] Fix MatchPreNormReshapeChain doc comment (reshape order)
hariharans29 May 12, 2026
f247ae0
[Test] Fix uniform_int_distribution assertion in GQA pre-norm fusion …
hariharans29 May 13, 2026
21e09aa
Address Copilot review comments on PR #28484
hariharans29 May 15, 2026
1ee43af
Update docs
hariharans29 May 15, 2026
60debba
Merge branch 'main' into hari/webgpu_perf_2
hariharans29 May 20, 2026
c69d3f4
Merge remote-tracking branch 'origin/main' into hari/webgpu_perf_2
hariharans29 May 21, 2026
e5e45a3
Merge remote-tracking branch 'origin/main' into HEAD
hariharans29 May 26, 2026
8 changes: 7 additions & 1 deletion docs/ContribOperators.md
@@ -2625,6 +2625,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>left_window_size for local attention (like Mistral). Default value is -1 meaning unused.</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>qk_norm_epsilon</tt> : float</dt>
<dd>Epsilon used by the per-head RMS norm applied to Q and K when q_norm_weight and k_norm_weight inputs are provided. Default value is 1e-6.</dd>
<dt><tt>qk_output</tt> : int</dt>
<dd>Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).</dd>
<dt><tt>rotary_interleaved</tt> : int</dt>
@@ -2639,7 +2641,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Quantization type for V cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.</dd>
</dl>

#### Inputs (7 - 14)
#### Inputs (7 - 16)

<dl>
<dt><tt>query</tt> : T</dt>
@@ -2670,6 +2672,10 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Scale tensor for past_key.</dd>
<dt><tt>v_scale</tt> (optional) : T_KV_SCALE</dt>
<dd>Scale tensor for past_value.</dd>
<dt><tt>q_norm_weight</tt> (optional) : T</dt>
<dd>Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion folds that pattern into this input. Currently honored by the native WebGPU execution provider only; JSEP WebGPU/JS and other EPs must reject the node when this input is set.</dd>
<dt><tt>k_norm_weight</tt> (optional) : T</dt>
<dd>Optional 1D tensor of shape (head_size). See q_norm_weight. Must be provided together with q_norm_weight.</dd>
</dl>

#### Outputs (3 - 4)
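The per-head RMS normalization that `q_norm_weight` / `k_norm_weight` enable can be sketched as follows. This is a minimal CPU reference under stated assumptions, not the WebGPU kernel from this PR: the function name `PerHeadRmsNorm`, the single-token packed layout `[num_heads * head_size]`, and float accumulation are all illustrative choices. It uses the usual RMS norm formula `y = x / sqrt(mean(x^2) + epsilon) * weight`, with one shared `weight` of shape `(head_size)` applied to every head.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference helper (not part of ONNX Runtime).
// Normalizes each head of a packed [num_heads * head_size] vector independently.
// `weight` has head_size elements and is shared by all heads.
std::vector<float> PerHeadRmsNorm(const std::vector<float>& x,
                                  const std::vector<float>& weight,
                                  std::size_t num_heads,
                                  float epsilon = 1e-6f) {
  const std::size_t head_size = weight.size();
  assert(head_size > 0 && x.size() == num_heads * head_size);

  std::vector<float> y(x.size());
  for (std::size_t h = 0; h < num_heads; ++h) {
    const std::size_t base = h * head_size;

    // Mean of squares over this head only.
    float sum_sq = 0.0f;
    for (std::size_t i = 0; i < head_size; ++i) {
      sum_sq += x[base + i] * x[base + i];
    }
    const float inv_rms =
        1.0f / std::sqrt(sum_sq / static_cast<float>(head_size) + epsilon);

    // Scale by the inverse RMS, then by the per-channel weight.
    for (std::size_t i = 0; i < head_size; ++i) {
      y[base + i] = x[base + i] * inv_rms * weight[i];
    }
  }
  return y;
}
```

In a fused kernel this step would run on Q and on K (each with its own weight) before rotary embedding, which is the ordering the input description above states.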
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -1088,7 +1088,7 @@ The **OpSet Version** column uses the following notation:
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)<br/> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)<br/> **T_KV_SCALE** = tensor(float)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *in* q_norm_weight:**T**<br> *in* k_norm_weight:**T**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)<br/> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)<br/> **T_KV_SCALE** = tensor(float)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LinearAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_state:**S**<br> *in* decay:**T**<br> *in* beta:**T**<br> *out* output:**T**<br> *out* present_state:**S**|1+|**T** = tensor(float), tensor(float16)|
@@ -1575,7 +1575,7 @@ The **OpSet Version** column uses the following notation:
|FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *in* q_norm_weight:**T**<br> *in* k_norm_weight:**T**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
10 changes: 10 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -328,6 +328,16 @@ const generatePositionIdsProgramInfo = (
};

export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
// q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the WebGPU-only
// GroupQueryAttentionPreNormFusion optimizer pass. JSEP does not implement the fused
// per-head Q/K RMS normalization prologue, so reject the node if either input is present
// (regardless of rank, including scalars) rather than silently dropping the normalization.
if ((context.inputs.length > 14 && context.inputs[14]) || (context.inputs.length > 15 && context.inputs[15])) {
throw new Error(
'GroupQueryAttention (JSEP): q_norm_weight / k_norm_weight inputs are not supported. ' +
'The per-head Q/K RMS normalization prologue is implemented only on the native WebGPU EP.',
);
}
const params = validateInputs(context.inputs, attributes);
if (context.inputs[0].dims.length === 5) {
throw new Error('Packed QKV is not implemented');
12 changes: 12 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -84,6 +84,18 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
"kv_cache_bit_width must be 0 when quantization is disabled, got ", kv_cache_bit_width_);
}

// q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only
// GroupQueryAttentionPreNormFusion optimizer pass. The CPU kernel does not implement
// the fused per-head Q/K RMS normalization prologue, so reject the node if either input
// is present rather than silently dropping the normalization.
if ((context->InputCount() > 14 && context->Input<Tensor>(14) != nullptr) ||
(context->InputCount() > 15 && context->Input<Tensor>(15) != nullptr)) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"GroupQueryAttention (CPU): q_norm_weight / k_norm_weight inputs are not supported. "
"The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP.");
}

GroupQueryAttentionParameters parameters = {};
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
key,
12 changes: 12 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -167,6 +167,18 @@ Status GroupQueryAttention<T, U>::ComputeInternal(OpKernelContext* context) cons
const Tensor* k_scale = context->Input<Tensor>(12);
const Tensor* v_scale = context->Input<Tensor>(13);

// q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only
// GroupQueryAttentionPreNormFusion optimizer pass. The CUDA kernel does not implement
// the fused per-head Q/K RMS normalization prologue, so reject the node if either input
// is present rather than silently dropping the normalization.
if ((context->InputCount() > 14 && context->Input<Tensor>(14) != nullptr) ||
(context->InputCount() > 15 && context->Input<Tensor>(15) != nullptr)) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"GroupQueryAttention (CUDA): q_norm_weight / k_norm_weight inputs are not supported. "
"The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP.");
}

if (k_quant_type_ != KVQuantizationType::NONE) {
if (k_scale == nullptr) {
return ORT_MAKE_STATUS(
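The CPU, CUDA, and JSEP hunks above all apply the same guard: if input 14 or input 15 is present, fail loudly instead of running attention without the normalization. A standalone sketch of that check is shown below. Everything in it is hypothetical scaffolding for illustration (`FakeTensor`, `CheckNoQkNormInputs`, the index constants, and returning a string rather than an ONNX Runtime `Status`); only the input indices and the "reject if either is present" rule come from the diff.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor; a null pointer models an omitted optional input.
struct FakeTensor {};

// Input slots taken from the diff: q_norm_weight is input 14, k_norm_weight is input 15.
constexpr std::size_t kQNormWeightIndex = 14;
constexpr std::size_t kKNormWeightIndex = 15;

// Returns an empty string when the node is acceptable for an EP that lacks the
// fused Q/K RMS norm prologue, or an error message when either input is set.
std::string CheckNoQkNormInputs(const std::vector<const FakeTensor*>& inputs,
                                const std::string& ep_name) {
  // An input counts as present only if the slot exists and is non-null.
  auto present = [&inputs](std::size_t index) {
    return inputs.size() > index && inputs[index] != nullptr;
  };

  if (present(kQNormWeightIndex) || present(kKNormWeightIndex)) {
    return "GroupQueryAttention (" + ep_name +
           "): q_norm_weight / k_norm_weight inputs are not supported.";
  }
  return {};
}
```

Checking the slot count before dereferencing keeps the guard safe for older models that supply only the first 14 inputs, and rejecting on either input (rather than requiring both) avoids silently dropping a half-specified normalization.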