diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 6df316097e719..d830585a708a8 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2625,6 +2625,8 @@ This version of the operator has been available since version 1 of the 'com.micr
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
+qk_norm_epsilon : float
+Epsilon used by the per-head RMS norm applied to Q and K when q_norm_weight and k_norm_weight inputs are provided. Default value is 1e-6.
qk_output : int
Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).
rotary_interleaved : int
@@ -2639,7 +2641,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Quantization type for V cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.
-#### Inputs (7 - 14)
+#### Inputs (7 - 16)
- query : T
@@ -2670,6 +2672,10 @@ This version of the operator has been available since version 1 of the 'com.micr
- Scale tensor for past_key.
- v_scale (optional) : T_KV_SCALE
- Scale tensor for past_value.
+- q_norm_weight (optional) : T
+- Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion folds that pattern into this input. Currently honored by the native WebGPU execution provider only; JSEP WebGPU/JS and other EPs must reject the node when this input is set.
+- k_norm_weight (optional) : T
+- Optional 1D tensor of shape (head_size). See q_norm_weight. Must be provided together with q_norm_weight.
#### Outputs (3 - 4)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index d0d8e750285d4..4a2399c000f46 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1088,7 +1088,7 @@ The **OpSet Version** column uses the following notation:
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)
**T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)
**T_KV_SCALE** = tensor(float)|
+|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*in* q_norm_weight:**T**
*in* k_norm_weight:**T**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)
**T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)
**T_KV_SCALE** = tensor(float)|
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LinearAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_state:**S**
*in* decay:**T**
*in* beta:**T**
*out* output:**T**
*out* present_state:**S**|1+|**T** = tensor(float), tensor(float16)|
@@ -1575,7 +1575,7 @@ The **OpSet Version** column uses the following notation:
|FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*in* q_norm_weight:**T**
*in* k_norm_weight:**T**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
index 9050c1bbb8816..ada0c65bdd8a7 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts
@@ -328,6 +328,16 @@ const generatePositionIdsProgramInfo = (
};
export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => {
+ // q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the WebGPU-only
+ // GroupQueryAttentionPreNormFusion optimizer pass. JSEP does not implement the fused
+ // per-head Q/K RMS normalization prologue, so reject the node if either input is present
+ // (regardless of rank, including scalars) rather than silently dropping the normalization.
+ if ((context.inputs.length > 14 && context.inputs[14]) || (context.inputs.length > 15 && context.inputs[15])) {
+ throw new Error(
+ 'GroupQueryAttention (JSEP): q_norm_weight / k_norm_weight inputs are not supported. ' +
+ 'The per-head Q/K RMS normalization prologue is implemented only on the native WebGPU EP.',
+ );
+ }
const params = validateInputs(context.inputs, attributes);
if (context.inputs[0].dims.length === 5) {
throw new Error('Packed QKV is not implemented');
diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index 4df5f6a349599..8eb7c73f8a4a9 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -84,6 +84,18 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
"kv_cache_bit_width must be 0 when quantization is disabled, got ", kv_cache_bit_width_);
}
+ // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only
+ // GroupQueryAttentionPreNormFusion optimizer pass. The CPU kernel does not implement
+ // the fused per-head Q/K RMS normalization prologue, so reject the node if either input
+ // is present rather than silently dropping the normalization.
+ if ((context->InputCount() > 14 && context->Input(14) != nullptr) ||
+ (context->InputCount() > 15 && context->Input(15) != nullptr)) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "GroupQueryAttention (CPU): q_norm_weight / k_norm_weight inputs are not supported. "
+ "The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP.");
+ }
+
GroupQueryAttentionParameters parameters = {};
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
key,
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 44408e6ce4af9..ea84fb973091c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -167,6 +167,18 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons
const Tensor* k_scale = context->Input(12);
const Tensor* v_scale = context->Input(13);
+ // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only
+ // GroupQueryAttentionPreNormFusion optimizer pass. The CUDA kernel does not implement
+ // the fused per-head Q/K RMS normalization prologue, so reject the node if either input
+ // is present rather than silently dropping the normalization.
+ if ((context->InputCount() > 14 && context->Input(14) != nullptr) ||
+ (context->InputCount() > 15 && context->Input(15) != nullptr)) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "GroupQueryAttention (CUDA): q_norm_weight / k_norm_weight inputs are not supported. "
+ "The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP.");
+ }
+
if (k_quant_type_ != KVQuantizationType::NONE) {
if (k_scale == nullptr) {
return ORT_MAKE_STATUS(
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
index e3b91bdbb82f4..7fd1271e37132 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc
@@ -9,6 +9,7 @@
#include "contrib_ops/webgpu/bert/flash_attention.h"
#include "core/common/narrow.h"
+#include "core/providers/webgpu/nn/layer_norm.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/shader_helper.h"
@@ -104,7 +105,11 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext&
return context.RunProgram(program);
}
-// Fused Q/K rotary embedding
+// Fused Q/K rotary embedding. When q_norm_weight and k_norm_weight are non-null, a per-head
+// RMS normalization (Q[c] *= inverseSqrt(mean(Q[..]^2)+eps) * q_norm_weight[c]; same for K)
+// is fused into the rotary kernel ahead of the rotation. This decode-only fast path replaces
+// the standalone SimplifiedLayerNormalization dispatches that GroupQueryAttentionPreNormFusion
+// folds away.
Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
const WebgpuAttentionParameters& params,
const Tensor* query_in,
@@ -113,7 +118,10 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
const Tensor* cos_cache,
const Tensor* sin_cache,
Tensor* query_out,
- Tensor* key_out) {
+ Tensor* key_out,
+ const Tensor* q_norm_weight = nullptr,
+ const Tensor* k_norm_weight = nullptr,
+ float qk_norm_epsilon = 0.0f) {
const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]);
const auto head_size = params.head_size_;
@@ -155,9 +163,10 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
1u});
// Dispatch computations only over the Q domain, and fuse K write operations using a head-index-based condition.
- FusedQKRotaryEmbeddingProgram program(params.rotary_interleaved_);
+ const bool has_qk_norm = (q_norm_weight != nullptr) && (k_norm_weight != nullptr);
+ FusedQKRotaryEmbeddingProgram program(params.rotary_interleaved_, has_qk_norm);
program
- .CacheHint(params.rotary_interleaved_)
+ .CacheHint(params.rotary_interleaved_, has_qk_norm)
.AddInputs({
{query_in, ProgramTensorMetadataDependency::TypeAndRank},
{key_in, ProgramTensorMetadataDependency::Rank},
@@ -178,8 +187,17 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context,
{gsl::make_span(k_global_dims)},
{gsl::make_span(k_input_output_strides)},
{q_domain_size},
+ {static_cast(head_size)},
+ {qk_norm_epsilon},
});
+ if (has_qk_norm) {
+ program.AddInputs({
+ {q_norm_weight, ProgramTensorMetadataDependency::Type},
+ {k_norm_weight, ProgramTensorMetadataDependency::Type},
+ });
+ }
+
return context.RunProgram(program);
}
@@ -196,6 +214,20 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
const Tensor* position_ids = context.Input(9); // TODO: support sliding window
const Tensor* attention_bias = context.Input(10);
const Tensor* head_sink = context.Input(11);
+ // Inputs 12 and 13 are k_scale / v_scale (KV-cache quant). Not consumed by WebGPU yet.
+ // Inputs 14 and 15 are q_norm_weight / k_norm_weight, populated by
+ // GroupQueryAttentionPreNormFusion. WebGPU supports these inputs for the configurations
+ // validated below (do_rotary, non-packed Q/K/V).
+ const Tensor* q_norm_weight = context.InputCount() > 14 ? context.Input(14) : nullptr;
+ const Tensor* k_norm_weight = context.InputCount() > 15 ? context.Input(15) : nullptr;
+ const bool has_qk_norm = (q_norm_weight != nullptr) && (k_norm_weight != nullptr);
+ // The current fused prologue only supports the Qwen3-style configuration that
+ // GroupQueryAttentionPreNormFusion targets: do_rotary, non-packed Q/K/V. Reject any
+ // other configuration so downstream rewrites cannot land silently.
+ ORT_RETURN_IF(((q_norm_weight != nullptr) ^ (k_norm_weight != nullptr)),
+ "GroupQueryAttention: q_norm_weight and k_norm_weight must be provided together.");
+ ORT_RETURN_IF(has_qk_norm && !do_rotary_,
+ "GroupQueryAttention: q_norm_weight / k_norm_weight require do_rotary=1.");
GroupQueryAttentionParameters params = {};
ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query,
@@ -227,6 +259,23 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT)))));
WebgpuAttentionParameters parameters(params);
+ ORT_RETURN_IF(has_qk_norm && parameters.is_packed_qkv_,
+ "GroupQueryAttention: q_norm_weight / k_norm_weight are not supported when QKV is packed.");
+ if (has_qk_norm) {
+ // The fused prologue indexes q/k_norm_weight as a 1-D tensor of length head_size. Validate
+ // shape here so a hand-authored model with a wrong shape fails with INVALID_ARGUMENT instead
+ // of silently reading the wrong offsets (or out of bounds).
+ const auto& q_norm_shape = q_norm_weight->Shape();
+ ORT_RETURN_IF_NOT(q_norm_shape.NumDimensions() == 1 &&
+ q_norm_shape[0] == static_cast(parameters.head_size_),
+ "GroupQueryAttention: q_norm_weight must be a 1-D tensor of shape [head_size=",
+ parameters.head_size_, "], got ", q_norm_shape.ToString(), ".");
+ const auto& k_norm_shape = k_norm_weight->Shape();
+ ORT_RETURN_IF_NOT(k_norm_shape.NumDimensions() == 1 &&
+ k_norm_shape[0] == static_cast(parameters.head_size_),
+ "GroupQueryAttention: k_norm_weight must be a 1-D tensor of shape [head_size=",
+ parameters.head_size_, "], got ", k_norm_shape.ToString(), ".");
+ }
TensorShapeVector output_shape(3);
output_shape[0] = static_cast(parameters.batch_size_);
output_shape[1] = static_cast(parameters.sequence_length_);
@@ -304,16 +353,63 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
value = &vSplit;
}
if (do_rotary_) {
+ // Per-head RMS normalization handling for Qwen3-style models (GQA inputs 14/15).
+ // - Decode (sequence_length == 1): fold the norm into the FusedQKRotaryEmbedding
+ // kernel. Each thread re-reads its head's head_size channels (Approach A); no
+ // reductions, no shared memory. Sub-microsecond overhead vs ~60us/layer SLN savings.
+ // - Prefill (sequence_length > 1): fall back to two standalone SimplifiedLayerNorm
+ // dispatches into scratch tensors (qNorm/kNorm), then run FusedQKRotaryEmbedding with the
+ // norm prologue disabled. Mirrors the pre-fusion dispatches so prefill should not regress.
+ Tensor qNorm;
+ Tensor kNorm;
+ const Tensor* q_for_rotary = query;
+ const Tensor* k_for_rotary = key;
+ const Tensor* q_norm_for_fused = nullptr;
+ const Tensor* k_norm_for_fused = nullptr;
+ const bool decode_norm_fast_path = has_qk_norm && parameters.sequence_length_ == 1;
+ if (has_qk_norm && !decode_norm_fast_path) {
+ qNorm = context.CreateGPUTensor(query->DataType(), query->Shape());
+ kNorm = context.CreateGPUTensor(key->DataType(), key->Shape());
+ const uint32_t q_norm_count =
+ static_cast(parameters.batch_size_) *
+ static_cast(parameters.sequence_length_) *
+ static_cast(parameters.num_heads_);
+ const uint32_t k_norm_count =
+ static_cast(parameters.batch_size_) *
+ static_cast(parameters.sequence_length_) *
+ static_cast(parameters.kv_num_heads_);
+ ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram(
+ context, query, q_norm_weight, /*bias=*/nullptr, qk_norm_epsilon_,
+ q_norm_count, static_cast(parameters.head_size_),
+ /*simplified=*/true, &qNorm, /*mean=*/nullptr, /*inv_std_dev=*/nullptr));
+ ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram(
+ context, key, k_norm_weight, /*bias=*/nullptr, qk_norm_epsilon_,
+ k_norm_count, static_cast(parameters.head_size_),
+ /*simplified=*/true, &kNorm, /*mean=*/nullptr, /*inv_std_dev=*/nullptr));
+ q_for_rotary = &qNorm;
+ k_for_rotary = &kNorm;
+ } else if (decode_norm_fast_path) {
+ q_norm_for_fused = q_norm_weight;
+ k_norm_for_fused = k_norm_weight;
+ }
// rotary QK
- qRotary = context.CreateGPUTensor(query->DataType(), query->Shape());
- kRotary = context.CreateGPUTensor(key->DataType(), key->Shape());
+ qRotary = context.CreateGPUTensor(q_for_rotary->DataType(), q_for_rotary->Shape());
+ kRotary = context.CreateGPUTensor(k_for_rotary->DataType(), k_for_rotary->Shape());
ORT_RETURN_IF_ERROR(RunFusedQKRotaryEmbedding(context, parameters,
- query, key,
+ q_for_rotary, k_for_rotary,
seqlen_k,
cos_cache, sin_cache,
- &qRotary, &kRotary));
+ &qRotary, &kRotary,
+ q_norm_for_fused, k_norm_for_fused,
+ qk_norm_epsilon_));
query = &qRotary;
key = &kRotary;
+ } else if (has_qk_norm) {
+ // Defensive: do_rotary_ guard above should make this unreachable, but keep it
+ // explicit so a future schema/config drift surfaces as a clear error.
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, NOT_IMPLEMENTED,
+ "GroupQueryAttention: q/k norm weights require do_rotary=1 (no rotary, no norm path).");
}
}
diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
index 4127a8928f38e..cbb5b806eb6ad 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h
@@ -58,6 +58,8 @@ class GroupQueryAttention final : public WebGpuKernel {
use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1;
local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1));
+
+ qk_norm_epsilon_ = info.GetAttrOrDefault("qk_norm_epsilon", 1e-6f);
}
int num_heads_; // number of attention heads of Q
@@ -69,6 +71,10 @@ class GroupQueryAttention final : public WebGpuKernel {
int local_window_size_;
bool use_smooth_softmax_;
+ // Epsilon used by per-head RMSNorm when q_norm_weight / k_norm_weight (inputs 14 / 15) are
+ // provided. Consumed whenever those optional norm inputs are used (decode fast path or
+ // prefill fallback), and ignored otherwise.
+ float qk_norm_epsilon_;
Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
};
diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
index 69d2db391ce3c..58f7b54bd8840 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc
@@ -66,22 +66,95 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
}
Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const {
- // Inputs
- const auto& q_input = shader.AddInput("q_input", ShaderUsage::UseUniform);
- const auto& k_input = shader.AddInput("k_input", ShaderUsage::UseUniform);
+ // Inputs. q_input/k_input use the element-type alias when has_qk_norm_ is true so we can
+ // mix in the f32-computed inverse-RMS scale at element-type precision.
+ const ShaderUsage qk_input_usage = has_qk_norm_
+ ? (ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias)
+ : ShaderUsage::UseUniform;
+ const auto& q_input = shader.AddInput("q_input", qk_input_usage);
+ const auto& k_input = shader.AddInput("k_input", qk_input_usage);
const auto& seqlens = shader.AddInput("seqlens", ShaderUsage::UseUniform);
const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform);
const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform);
+
+ // Optional per-head RMS norm weights (1D tensors of length head_size). When present,
+ // a fused per-head normalization is applied to Q/K before the rotary rotation:
+ // x_norm[c] = x[c] * inverseSqrt(mean(x[..]^2) + epsilon) * weight[c]
+ // Decode-only fast path: each thread re-reads its own head's head_size channels to
+ // compute the sum-of-squares (no reductions, no shared memory). The redundant L1
+ // traffic is sub-microsecond on Qwen3-1.7B decode geometry.
+ if (has_qk_norm_) {
+ shader.AddInput("q_norm_weight", ShaderUsage::UseUniform);
+ shader.AddInput("k_norm_weight", ShaderUsage::UseUniform);
+ }
+
// Outputs
const auto& q_output = shader.AddOutput("q_output", ShaderUsage::UseUniform);
const auto& k_output = shader.AddOutput("k_output", ShaderUsage::UseUniform);
const auto interleaved_str = interleaved_ ? "true" : "false";
- shader.MainFunctionBody()
+ auto& body = shader.MainFunctionBody();
+ body
<< " if (global_idx >= uniforms.q_domain_size) { return; }\n"
<< " let half_rotary_dim = uniforms.cos_cache_shape[1];\n"
<< " let bsnh = global_idx / uniforms.q_global_stride % uniforms.q_global_shape;\n"
+ << " let needs_k = bsnh[2] < uniforms.k_global_shape[2];\n";
+
+ // Per-head RMS computation (Approach A, no reductions). bsnh[3] is a rotary-domain
+ // coordinate (lower channel of a pair, or the pair index when interleaved), so the head
+ // base offset must exclude it: q_head_base / k_head_base are built from bsnh[0..2] and
+ // their strides only. With q_input_output_stride[3] == 1 (channel stride) this equals
+ // dot(bsnh, stride) - bsnh[3]; spelling it out keeps one base valid for the rotated
+ // branch (interleaved or not) and for the passthrough else-branch.
+ if (has_qk_norm_) {
+ body
+ << " let q_head_base = bsnh[0] * uniforms.q_input_output_stride[0]\n"
+ << " + bsnh[1] * uniforms.q_input_output_stride[1]\n"
+ << " + bsnh[2] * uniforms.q_input_output_stride[2];\n"
+ << " var q_sumsq: f32 = 0.0;\n"
+ << " for (var c: u32 = 0u; c < uniforms.head_size; c = c + 1u) {\n"
+ << " let q_v = f32(" << q_input.GetByOffset("q_head_base + c") << ");\n"
+ << " q_sumsq = q_sumsq + q_v * q_v;\n"
+ << " }\n"
+ << " let q_inv_rms = q_input_element_t(inverseSqrt(q_sumsq / f32(uniforms.head_size) + uniforms.qk_norm_epsilon));\n"
+ << " let k_head_base = bsnh[0] * uniforms.k_input_output_stride[0]\n"
+ << " + bsnh[1] * uniforms.k_input_output_stride[1]\n"
+ << " + bsnh[2] * uniforms.k_input_output_stride[2];\n"
+ << " var k_inv_rms = k_input_element_t(0);\n"
+ << " if (needs_k) {\n"
+ << " var k_sumsq: f32 = 0.0;\n"
+ << " for (var c: u32 = 0u; c < uniforms.head_size; c = c + 1u) {\n"
+ << " let k_v = f32(" << k_input.GetByOffset("k_head_base + c") << ");\n"
+ << " k_sumsq = k_sumsq + k_v * k_v;\n"
+ << " }\n"
+ << " k_inv_rms = k_input_element_t(inverseSqrt(k_sumsq / f32(uniforms.head_size) + uniforms.qk_norm_epsilon));\n"
+ << " }\n";
+ }
+
+ // Helpers that load Q/K and (when has_qk_norm_) apply the fused per-channel norm scale.
+ // The channel index expressions match the qi/qj/ki/kj/qk/kk computations used below.
+ auto load_q = [&](const std::string& off, const std::string& chan) {
+ if (!has_qk_norm_) {
+ return q_input.GetByOffset(off);
+ }
+ return std::string("(") + q_input.GetByOffset(off) + " * q_inv_rms * q_norm_weight[" + chan + "])";
+ };
+ auto load_k = [&](const std::string& off, const std::string& chan) {
+ if (!has_qk_norm_) {
+ return k_input.GetByOffset(off);
+ }
+ return std::string("(") + k_input.GetByOffset(off) + " * k_inv_rms * k_norm_weight[" + chan + "])";
+ };
+
+ // Channel index expressions for the rotated branch. For interleaved layout the pair is
+ // (2*bsnh[3], 2*bsnh[3]+1); otherwise it is (bsnh[3], bsnh[3]+half_rotary_dim).
+ const std::string c_i = interleaved_ ? "(2u * bsnh[3])" : "bsnh[3]";
+ const std::string c_j = interleaved_ ? "(2u * bsnh[3] + 1u)" : "(bsnh[3] + half_rotary_dim)";
+ // Channel index for the passthrough else-branch (only fires when head_size > 2 * half_rotary_dim).
+ const std::string c_k = "(bsnh[3] + half_rotary_dim)";
+
+ body
<< " if (bsnh[3] < half_rotary_dim) {\n"
<< " let batch_idx = bsnh[0];\n"
<< " let sequence_idx = bsnh[1];\n"
@@ -89,46 +162,51 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c
<< " let seqlen = u32(seqlen_i);\n"
<< " let total_seqlen = seqlen + 1u;\n"
<< " let past_seqlen = total_seqlen - uniforms.q_global_shape[1];\n"
- // position_id is derived from past_seqlen + sequence_idx (always non-negative).
<< " let position_id = past_seqlen + sequence_idx;\n"
<< " let qi = dot(bsnh, uniforms.q_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let qj = qi + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
- // Bounds check: position_id must be within cos/sin cache range.
- // On OOB, pass through input unchanged (same as CUDA kernel behavior).
- " let max_position = uniforms.cos_cache_shape[0];\n"
- " if (position_id >= max_position) {\n"
- << " " << q_output.SetByOffset("qi", q_input.GetByOffset("qi")) << "\n"
- << " " << q_output.SetByOffset("qj", q_input.GetByOffset("qj")) << "\n"
- << " if (bsnh[2] < uniforms.k_global_shape[2]) {\n"
+ << " let q_at_qi = " << load_q("qi", c_i) << ";\n"
+ << " let q_at_qj = " << load_q("qj", c_j) << ";\n"
+ << " let max_position = uniforms.cos_cache_shape[0];\n"
+ << " if (position_id >= max_position) {\n"
+ // Bounds check: position_id must be within cos/sin cache range.
+ // On OOB, skip the rotation and write the input through (norm-applied if has_qk_norm_).
+ << " " << q_output.SetByOffset("qi", "q_at_qi") << "\n"
+ << " " << q_output.SetByOffset("qj", "q_at_qj") << "\n"
+ << " if (needs_k) {\n"
<< " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
- << " " << k_output.SetByOffset("ki", k_input.GetByOffset("ki")) << "\n"
- << " " << k_output.SetByOffset("kj", k_input.GetByOffset("kj")) << "\n"
- " }\n"
- " } else {\n"
+ << " let k_at_ki = " << load_k("ki", c_i) << ";\n"
+ << " let k_at_kj = " << load_k("kj", c_j) << ";\n"
+ << " " << k_output.SetByOffset("ki", "k_at_ki") << "\n"
+ << " " << k_output.SetByOffset("kj", "k_at_kj") << "\n"
+ << " }\n"
+ << " } else {\n"
<< " let cos_v = " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n"
<< " let sin_v = " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n"
- << " let q_re = " << q_input.GetByOffset("qi") << " * cos_v - " << q_input.GetByOffset("qj") << " * sin_v;\n"
+ << " let q_re = q_at_qi * cos_v - q_at_qj * sin_v;\n"
<< " " << q_output.SetByOffset("qi", "q_re") << "\n"
- << " let q_im = " << q_input.GetByOffset("qi") << " * sin_v + " << q_input.GetByOffset("qj") << " * cos_v;\n"
+ << " let q_im = q_at_qi * sin_v + q_at_qj * cos_v;\n"
<< " " << q_output.SetByOffset("qj", "q_im") << "\n"
- // Conditionally process Key (only for heads that exist in K domain)
- << " if (bsnh[2] < uniforms.k_global_shape[2]) {\n"
+ << " if (needs_k) {\n"
<< " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n"
<< " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n"
- << " let k_re = " << k_input.GetByOffset("ki") << " * cos_v - " << k_input.GetByOffset("kj") << " * sin_v;\n"
+ << " let k_at_ki = " << load_k("ki", c_i) << ";\n"
+ << " let k_at_kj = " << load_k("kj", c_j) << ";\n"
+ << " let k_re = k_at_ki * cos_v - k_at_kj * sin_v;\n"
<< " " << k_output.SetByOffset("ki", "k_re") << "\n"
- << " let k_im = " << k_input.GetByOffset("ki") << " * sin_v + " << k_input.GetByOffset("kj") << " * cos_v;\n"
+ << " let k_im = k_at_ki * sin_v + k_at_kj * cos_v;\n"
<< " " << k_output.SetByOffset("kj", "k_im") << "\n"
- " }\n"
- " }\n"
+ << " }\n"
+ << " }\n"
<< " } else {\n"
<< " let qk = dot(bsnh, uniforms.q_input_output_stride) + half_rotary_dim;\n"
- << " " << q_output.SetByOffset("qk", q_input.GetByOffset("qk")) << "\n"
- // Conditionally process Key (only for heads that exist in K domain)
- << " if (bsnh[2] < uniforms.k_global_shape[2]) {\n"
+ << " let q_at_qk = " << load_q("qk", c_k) << ";\n"
+ << " " << q_output.SetByOffset("qk", "q_at_qk") << "\n"
+ << " if (needs_k) {\n"
<< " let kk = dot(bsnh, uniforms.k_input_output_stride) + half_rotary_dim;\n"
- << " " << k_output.SetByOffset("kk", k_input.GetByOffset("kk")) << "\n"
+ << " let k_at_kk = " << load_k("kk", c_k) << ";\n"
+ << " " << k_output.SetByOffset("kk", "k_at_kk") << "\n"
<< " }\n"
<< " }\n";
return Status::OK();
diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h
index e3dc4468cb3ed..dd16630e436bc 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h
@@ -31,12 +31,20 @@ class RotaryEmbeddingProgram final : public Program {
class FusedQKRotaryEmbeddingProgram final : public Program {
public:
- FusedQKRotaryEmbeddingProgram(bool interleaved) : Program{"FusedQKRotaryEmbedding"}, interleaved_{interleaved} {}
+ FusedQKRotaryEmbeddingProgram(bool interleaved, bool has_qk_norm)
+ : Program{"FusedQKRotaryEmbedding"},
+ interleaved_{interleaved},
+ has_qk_norm_{has_qk_norm} {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
// q_* describes query rotation domain (same definition as existing program)
- // k_* describes key rotation domain
+ // k_* describes key rotation domain.
+ // When has_qk_norm_ is true, the program also fuses a per-head RMS normalization
+ // (epsilon = qk_norm_epsilon, scale = q_norm_weight / k_norm_weight) over the
+ // head_size channels of Q and K before the rotary rotation. head_size and
+ // qk_norm_epsilon are required uniforms when has_qk_norm_ is true; they are
+ // ignored otherwise but must still be supplied (callers pass placeholder values).
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"scale", ProgramUniformVariableDataType::Float32},
{"q_global_shape", ProgramUniformVariableDataType::Uint32},
@@ -44,10 +52,13 @@ class FusedQKRotaryEmbeddingProgram final : public Program SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion "
+ "folds that pattern into this input. Currently honored by the native WebGPU execution provider only; "
+ "JSEP WebGPU/JS and other EPs must reject the node when this input is set.",
+ "T",
+ OpSchema::Optional)
+ .Input(15,
+ "k_norm_weight",
+ "Optional 1D tensor of shape (head_size). See q_norm_weight. Must be provided together with q_norm_weight.",
+ "T",
+ OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, hidden_size)",
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 702f20a96dccf..2f9a083a9b1be 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -55,6 +55,7 @@
#include "core/optimizer/layer_norm_fusion.h"
#include "core/optimizer/matmul_activation_fusion.h"
#include "core/optimizer/matmul_add_fusion.h"
+#include "core/optimizer/group_query_attention_pre_norm_fusion.h"
#include "core/optimizer/matmul_bn_fusion.h"
#include "core/optimizer/matmul_integer_to_float.h"
#include "core/optimizer/matmul_scale_fusion.h"
@@ -446,6 +447,8 @@ InlinedVector> GenerateTransformers(
#endif
transformers.emplace_back(std::make_unique(cpu_ep));
+ transformers.emplace_back(std::make_unique(
+ InlinedHashSet{onnxruntime::kWebGpuExecutionProvider}));
#endif // !defined(DISABLE_CONTRIB_OPS)
// The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their
diff --git a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc
new file mode 100644
index 0000000000000..909229822f134
--- /dev/null
+++ b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc
@@ -0,0 +1,403 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/optimizer/group_query_attention_pre_norm_fusion.h"
+
+#include
+#include
+#include
+#include
+
+#include "core/graph/graph_utils.h"
+#include "core/graph/node_attr_utils.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/utils.h"
+
+namespace onnxruntime {
+
+namespace {
+
+constexpr const char* kQkNormEpsilonAttrName = "qk_norm_epsilon";  // attribute written onto the fused GroupQueryAttention node
+constexpr float kEpsilonTolerance = 1e-9f;  // largest |q_epsilon - k_epsilon| still accepted when pairing the Q-side and K-side norms
+
+// Reports whether `node` carries a usable input at slot `index`: the slot must be inside
+// the input list, hold a non-null NodeArg, and that NodeArg must have a non-empty name.
+bool HasInput(const Node& node, size_t index) {
+  const auto& input_defs = node.InputDefs();
+  if (index >= input_defs.size()) {
+    return false;
+  }
+  const NodeArg* candidate = input_defs[index];
+  if (candidate == nullptr) {
+    return false;
+  }
+  return !candidate->Name().empty();
+}
+
+// Reports whether `node` declares a named output at slot `index`: the slot must be inside
+// the output list, hold a non-null NodeArg, and that NodeArg must have a non-empty name.
+bool HasProducedOutput(const Node& node, size_t index) {
+  const auto& output_defs = node.OutputDefs();
+  if (index >= output_defs.size()) {
+    return false;
+  }
+  const NodeArg* candidate = output_defs[index];
+  if (candidate == nullptr) {
+    return false;
+  }
+  return !candidate->Name().empty();
+}
+
+// Returns true when `arg` is named and some entry of graph.GetOutputs() has the same name.
+// A null or unnamed `arg` is never considered a graph output.
+bool IsGraphOutput(const Graph& graph, const NodeArg* arg) {
+  const bool is_named = (arg != nullptr) && !arg->Name().empty();
+  if (is_named) {
+    const auto& wanted_name = arg->Name();
+    for (const auto* candidate : graph.GetOutputs()) {
+      if (candidate == nullptr) {
+        continue;
+      }
+      if (candidate->Name() == wanted_name) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Starting at input slot `consumer_input_index` of `consumer`, walks producers backwards and matches:
+// producer_proj -> Reshape(reshape_inner) -> SimplifiedLayerNormalization(sln) -> Reshape(reshape_outer) -> consumer
+// `reshape_inner` is nearest the projection; the last dim of its output must be the static value `expected_head_size`.
+// `reshape_outer` is nearest the consumer; the last dim of its output must be the static value `expected_hidden_size`.
+// The SLN must use axis -1, a constant 1-D scale initializer of length `expected_head_size`, and one element type for input, scale and output.
+// Returns true and fills every out-parameter on a match; otherwise returns false (out-parameters are cleared on entry and only set on success).
+// Each of the three matched nodes must have exactly one output edge and its first output must not be a graph output.
+bool MatchPreNormReshapeChain(Graph& graph,
+ Node& consumer,
+ int consumer_input_index,
+ int64_t expected_head_size,
+ int64_t expected_hidden_size,
+ Node*& reshape_outer_out,
+ Node*& sln_out,
+ Node*& reshape_inner_out,
+ NodeArg*& projection_arg_out,
+ NodeArg*& norm_weight_arg_out,
+ float& epsilon_out) {
+ reshape_outer_out = nullptr;
+ sln_out = nullptr;
+ reshape_inner_out = nullptr;
+ projection_arg_out = nullptr;
+ norm_weight_arg_out = nullptr;
+ epsilon_out = 0.0f;
+
+ if (consumer_input_index < 0 ||
+ static_cast(consumer_input_index) >= consumer.InputDefs().size()) {
+ return false;
+ }
+
+ NodeArg* consumer_input = consumer.MutableInputDefs()[consumer_input_index];
+ if (consumer_input == nullptr || consumer_input->Name().empty()) {
+ return false;
+ }
+
+ Node* reshape_outer = graph.GetMutableProducerNode(consumer_input->Name());
+ if (reshape_outer == nullptr ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_outer, "Reshape", {5, 13, 14, 19, 21, 23})) {
+ return false;
+ }
+ if (reshape_outer->GetOutputEdgesCount() != 1) {  // the single edge is the one into `consumer`
+ return false;
+ }
+ if (IsGraphOutput(graph, reshape_outer->OutputDefs()[0])) {
+ return false;
+ }
+
+ // Validate that the outer reshape's inferred output shape has a static last dim equal to the hidden size (heads * head_size).
+ const auto* reshape_outer_shape = reshape_outer->OutputDefs()[0]->Shape();
+ if (reshape_outer_shape == nullptr || reshape_outer_shape->dim_size() < 1) {
+ return false;
+ }
+ const auto& reshape_outer_last = reshape_outer_shape->dim(reshape_outer_shape->dim_size() - 1);
+ if (!reshape_outer_last.has_dim_value() || reshape_outer_last.dim_value() != expected_hidden_size) {
+ return false;
+ }
+
+ if (reshape_outer->InputDefs().empty() || reshape_outer->InputDefs()[0] == nullptr) {
+ return false;
+ }
+ Node* sln = graph.GetMutableProducerNode(reshape_outer->InputDefs()[0]->Name());
+ if (sln == nullptr ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*sln, "SimplifiedLayerNormalization", {1})) {
+ return false;
+ }
+ if (sln->GetOutputEdgesCount() != 1) {
+ return false;
+ }
+ if (IsGraphOutput(graph, sln->OutputDefs()[0])) {
+ return false;
+ }
+ // Any additional SLN output slot that carries a name blocks the match, whether or not anything consumes it.
+ for (size_t i = 1; i < sln->OutputDefs().size(); ++i) {
+ if (HasProducedOutput(*sln, i)) {
+ return false;
+ }
+ }
+ if (sln->InputDefs().size() < 2 || sln->InputDefs()[1] == nullptr ||
+ sln->InputDefs()[1]->Name().empty()) {
+ return false;
+ }
+
+ // SimplifiedLayerNormalization permits its input (T), scale (V) and output (T) to use different
+ // element types. The fused GroupQueryAttention input slots reuse the projection's element type
+ // (T), so we can only fuse when scale and output also use T -- otherwise the rewrite would
+ // change the node's type constraints and produce a semantically different graph. Require all
+ // three to match before fusing.
+ auto get_elem_type = [](const NodeArg* arg) -> int32_t {
+ if (arg == nullptr) {
+ return ONNX_NAMESPACE::TensorProto::UNDEFINED;
+ }
+ const auto* type_proto = arg->TypeAsProto();
+ if (type_proto == nullptr || !type_proto->has_tensor_type() ||
+ !type_proto->tensor_type().has_elem_type()) {
+ return ONNX_NAMESPACE::TensorProto::UNDEFINED;
+ }
+ return type_proto->tensor_type().elem_type();
+ };
+ const int32_t sln_input_elem_type = get_elem_type(sln->InputDefs()[0]);
+ const int32_t sln_scale_elem_type = get_elem_type(sln->InputDefs()[1]);
+ const int32_t sln_output_elem_type = get_elem_type(sln->OutputDefs()[0]);
+ if (sln_input_elem_type == ONNX_NAMESPACE::TensorProto::UNDEFINED ||
+ sln_input_elem_type != sln_scale_elem_type ||
+ sln_input_elem_type != sln_output_elem_type) {
+ return false;
+ }
+
+ // Norm weight must be a constant initializer with exactly one dimension of length head_size.
+ NodeArg* norm_weight_arg = sln->MutableInputDefs()[1];
+ const ONNX_NAMESPACE::TensorProto* norm_weight_tensor =
+ graph_utils::GetConstantInitializer(graph, norm_weight_arg->Name());
+ if (norm_weight_tensor == nullptr) {
+ return false;
+ }
+ if (norm_weight_tensor->dims_size() != 1 || norm_weight_tensor->dims(0) != expected_head_size) {
+ return false;
+ }
+
+ const auto* sln_axis_attr = graph_utils::GetNodeAttribute(*sln, "axis");
+ const int64_t sln_axis = (sln_axis_attr == nullptr) ? -1 : sln_axis_attr->i();  // an absent axis attribute is treated as -1
+ if (sln_axis != -1) {
+ return false;
+ }
+ const auto* sln_eps_attr = graph_utils::GetNodeAttribute(*sln, "epsilon");
+ const float sln_eps = (sln_eps_attr == nullptr) ? 1e-5f : sln_eps_attr->f();  // an absent epsilon attribute is treated as 1e-5f
+
+ // Inner reshape (between projection and SLN).
+ if (sln->InputDefs().empty() || sln->InputDefs()[0] == nullptr) {  // NOTE(review): cannot fire after the size and element-type checks above; defensive only
+ return false;
+ }
+ Node* reshape_inner = graph.GetMutableProducerNode(sln->InputDefs()[0]->Name());
+ if (reshape_inner == nullptr ||
+ !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_inner, "Reshape", {5, 13, 14, 19, 21, 23})) {
+ return false;
+ }
+ if (reshape_inner->GetOutputEdgesCount() != 1) {
+ return false;
+ }
+ if (IsGraphOutput(graph, reshape_inner->OutputDefs()[0])) {
+ return false;
+ }
+ const auto* reshape_inner_shape = reshape_inner->OutputDefs()[0]->Shape();
+ if (reshape_inner_shape == nullptr || reshape_inner_shape->dim_size() < 1) {
+ return false;
+ }
+ const auto& reshape_inner_last = reshape_inner_shape->dim(reshape_inner_shape->dim_size() - 1);
+ if (!reshape_inner_last.has_dim_value() || reshape_inner_last.dim_value() != expected_head_size) {
+ return false;
+ }
+
+ if (reshape_inner->InputDefs().empty() || reshape_inner->InputDefs()[0] == nullptr) {
+ return false;
+ }
+
+ reshape_outer_out = reshape_outer;
+ sln_out = sln;
+ reshape_inner_out = reshape_inner;
+ projection_arg_out = reshape_inner->MutableInputDefs()[0];  // NOTE(review): this arg's own shape is not compared with reshape_outer's output shape
+ norm_weight_arg_out = norm_weight_arg;
+ epsilon_out = sln_eps;
+ return true;
+}
+
+} // namespace
+
+Status GroupQueryAttentionPreNormFusion::ApplyImpl(Graph& graph,  // folds matched Q/K Reshape -> SLN -> Reshape chains into each eligible GroupQueryAttention node
+ bool& modified,  // set to true whenever a node is rewritten; never reset to false here
+ int graph_level,
+ const logging::Logger& logger) const {
+ GraphViewer graph_viewer(graph);
+ const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+
+ for (auto node_index : node_topology_list) {
+ auto* node_ptr = graph.GetNode(node_index);
+ if (node_ptr == nullptr) {  // index no longer resolves to a node (this loop removes nodes as it fuses)
+ continue;
+ }
+ Node& node = *node_ptr;
+ ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+
+ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "GroupQueryAttention", {1}, kMSDomain) ||
+ !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
+ continue;
+ }
+
+ // Already fused? (either norm-weight slot, 14 or 15, is populated)
+ if (HasInput(node, 14) || HasInput(node, 15)) {
+ continue;
+ }
+
+ // Need at least query (0), key (1), value (2), past_key (3) so we can read head_size.
+ // Requiring K at slot 1 also excludes the packed-QKV form (Q occupies slot 0 and K/V
+ // slots are empty), which the WebGPU fused prologue does not support.
+ if (node.InputDefs().size() < 4 || !HasInput(node, 0) || !HasInput(node, 1) || !HasInput(node, 2)) {
+ continue;
+ }
+
+ // The fused decode prologue only applies when rotary embedding is enabled (Qwen3-style
+ // configuration). If the GQA node has do_rotary=0 the kernel will reject the rewritten
+ // node, so skip the fusion here to avoid that regression.
+ const auto& gqa_attrs = node.GetAttributes();
+ auto do_rotary_it = gqa_attrs.find("do_rotary");
+ const int64_t do_rotary = (do_rotary_it == gqa_attrs.end()) ? 0 : do_rotary_it->second.i();  // an absent attribute counts as 0
+ if (do_rotary != 1) {
+ continue;
+ }
+ const NodeArg* past_key_arg = node.InputDefs()[3];
+ if (past_key_arg == nullptr || past_key_arg->Shape() == nullptr ||
+ past_key_arg->Shape()->dim_size() < 4) {
+ continue;
+ }
+ const auto& head_size_dim = past_key_arg->Shape()->dim(3);  // head_size is taken from index 3 of past_key's shape
+ if (!head_size_dim.has_dim_value()) {  // a symbolic/unknown head_size cannot be validated against the chains
+ continue;
+ }
+ const int64_t head_size = head_size_dim.dim_value();
+
+ auto num_heads_it = gqa_attrs.find("num_heads");
+ auto kv_num_heads_it = gqa_attrs.find("kv_num_heads");
+ if (num_heads_it == gqa_attrs.end() || kv_num_heads_it == gqa_attrs.end()) {
+ continue;
+ }
+ const int64_t num_heads = num_heads_it->second.i();
+ const int64_t kv_num_heads = kv_num_heads_it->second.i();
+ const int64_t q_hidden_size = num_heads * head_size;
+ const int64_t kv_hidden_size = kv_num_heads * head_size;
+
+ // Match pre-norm Reshape -> SLN -> Reshape on Q (slot 0) and K (slot 1); both sides must match or the node is left alone.
+ Node* q_reshape_outer = nullptr;
+ Node* q_sln = nullptr;
+ Node* q_reshape_inner = nullptr;
+ NodeArg* q_projection_arg = nullptr;
+ NodeArg* q_norm_weight_arg = nullptr;
+ float q_epsilon = 0.0f;
+ if (!MatchPreNormReshapeChain(graph, node, /*consumer_input_index=*/0, head_size, q_hidden_size,
+ q_reshape_outer, q_sln, q_reshape_inner,
+ q_projection_arg, q_norm_weight_arg, q_epsilon)) {
+ continue;
+ }
+
+ Node* k_reshape_outer = nullptr;
+ Node* k_sln = nullptr;
+ Node* k_reshape_inner = nullptr;
+ NodeArg* k_projection_arg = nullptr;
+ NodeArg* k_norm_weight_arg = nullptr;
+ float k_epsilon = 0.0f;
+ if (!MatchPreNormReshapeChain(graph, node, /*consumer_input_index=*/1, head_size, kv_hidden_size,
+ k_reshape_outer, k_sln, k_reshape_inner,
+ k_projection_arg, k_norm_weight_arg, k_epsilon)) {
+ continue;
+ }
+
+ if (std::fabs(q_epsilon - k_epsilon) > kEpsilonTolerance) {  // the fused node has a single qk_norm_epsilon, so both norms must agree
+ continue;
+ }
+
+ LOGS(logger, VERBOSE) << "GroupQueryAttentionPreNormFusion: matched gqa='" << node.Name()
+ << "' q_sln='" << q_sln->Name() << "' k_sln='" << k_sln->Name()
+ << "' head_size=" << head_size
+ << " num_heads=" << num_heads << " kv_num_heads=" << kv_num_heads
+ << " epsilon=" << q_epsilon;
+
+ // Build new GQA inputs: copy existing inputs, replace 0/1 with projection outputs,
+ // pad up to slot 13 with empty NodeArgs, then add q/k norm weights at 14/15.
+ NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr);  // shared placeholder for every absent optional slot
+ InlinedVector new_inputs;
+ new_inputs.reserve(16);
+ for (size_t i = 0; i < 16; ++i) {  // 16 = slots 0..15; the last two are q_norm_weight / k_norm_weight
+ if (i == 0) {
+ new_inputs.push_back(q_projection_arg);
+ } else if (i == 1) {
+ new_inputs.push_back(k_projection_arg);
+ } else if (i == 14) {
+ new_inputs.push_back(q_norm_weight_arg);
+ } else if (i == 15) {
+ new_inputs.push_back(k_norm_weight_arg);
+ } else if (i < node.InputDefs().size()) {
+ NodeArg* existing = node.MutableInputDefs()[i];
+ new_inputs.push_back((existing != nullptr && !existing->Name().empty()) ? existing : &empty_arg);
+ } else {
+ new_inputs.push_back(&empty_arg);
+ }
+ }
+
+ // Outputs: reuse the original output NodeArgs so downstream consumers and graph outputs are preserved.
+ InlinedVector new_outputs;
+ new_outputs.reserve(node.OutputDefs().size());
+ for (auto* out : node.OutputDefs()) {
+ new_outputs.push_back(const_cast(out));
+ }
+
+ // Copy attributes and add qk_norm_epsilon (the Q-side value; K was checked to agree within tolerance).
+ NodeAttributes new_attrs = node.GetAttributes();
+ utils::SetNodeAttribute(utils::MakeAttribute(std::string(kQkNormEpsilonAttrName), q_epsilon), new_attrs);
+
+ const std::string original_name = node.Name();  // copied now because `node` is removed below
+ const std::string original_ep = node.GetExecutionProviderType();
+
+ // Snapshot the GQA's input and output edges before removal; all are re-created on the fused node except input slots 0/1.
+ auto gqa_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node);
+ auto gqa_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node);
+
+ // Remove all involved nodes (their input edges from elsewhere drop with them).
+ graph_utils::RemoveNodeOutputEdges(graph, node);
+ graph.RemoveNode(node.Index());
+ graph_utils::RemoveNodeOutputEdges(graph, *q_reshape_outer);
+ graph.RemoveNode(q_reshape_outer->Index());
+ graph_utils::RemoveNodeOutputEdges(graph, *q_sln);
+ graph.RemoveNode(q_sln->Index());
+ graph_utils::RemoveNodeOutputEdges(graph, *q_reshape_inner);
+ graph.RemoveNode(q_reshape_inner->Index());
+ graph_utils::RemoveNodeOutputEdges(graph, *k_reshape_outer);
+ graph.RemoveNode(k_reshape_outer->Index());
+ graph_utils::RemoveNodeOutputEdges(graph, *k_sln);
+ graph.RemoveNode(k_sln->Index());
+ graph_utils::RemoveNodeOutputEdges(graph, *k_reshape_inner);
+ graph.RemoveNode(k_reshape_inner->Index());
+
+ Node& fused = graph.AddNode(graph.GenerateNodeName(original_name + "_qknorm"),
+ "GroupQueryAttention",
+ "GroupQueryAttention with fused per-head Q/K RMSNorm",
+ new_inputs,
+ new_outputs,
+ &new_attrs,
+ kMSDomain);
+ fused.SetExecutionProviderType(original_ep);  // keep the fused node on the EP the original GQA was assigned to
+
+ // Re-create the edges that fed the original GQA on every slot except 0 and 1. Those two
+ // slots used to be fed by the outer Reshapes (now removed) and are fed by the projection
+ // outputs instead; their edges are added explicitly below rather than relying on a later
+ // Resolve() to reconnect them by NodeArg name.
+ // (Edges into the remaining slots come from nodes this rewrite does not remove.)
+ for (const auto& e : gqa_input_edges) {
+ if (e.dst_arg_index == 0 || e.dst_arg_index == 1) {
+ continue;
+ }
+ graph.AddEdge(e.src_node, fused.Index(), e.src_arg_index, e.dst_arg_index);
+ }
+ // Add explicit edges for the new query/key inputs; skipped when the projection has no producer node (e.g. it is a graph input).
+ if (Node* q_proj_node = graph.GetMutableProducerNode(q_projection_arg->Name())) {
+ const int src_idx = graph_utils::GetNodeOutputIndexFromOutputName(*q_proj_node, q_projection_arg->Name());
+ graph.AddEdge(q_proj_node->Index(), fused.Index(), src_idx, 0);
+ }
+ if (Node* k_proj_node = graph.GetMutableProducerNode(k_projection_arg->Name())) {
+ const int src_idx = graph_utils::GetNodeOutputIndexFromOutputName(*k_proj_node, k_projection_arg->Name());
+ graph.AddEdge(k_proj_node->Index(), fused.Index(), src_idx, 1);
+ }
+ // Re-create the downstream edges that left the original GQA, now sourced from the fused node.
+ for (const auto& e : gqa_output_edges) {
+ graph.AddEdge(fused.Index(), e.dst_node, e.src_arg_index, e.dst_arg_index);
+ }
+
+ modified = true;
+ }
+
+ return Status::OK();
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h
new file mode 100644
index 0000000000000..b69199bb5324d
--- /dev/null
+++ b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h
@@ -0,0 +1,47 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+/**
+@Class GroupQueryAttentionPreNormFusion
+
+Folds the Qwen3-style per-head Q/K RMSNorm prologue into the GroupQueryAttention
+node by adding optional q_norm_weight and k_norm_weight inputs (slots 14 and 15)
+and a qk_norm_epsilon attribute. The transform looks for the following pattern
+on inputs 0 (query) and 1 (key) of an unfused GroupQueryAttention node:
+
+ Q_proj_out -> Reshape[*,*,head_size]
+ -> SimplifiedLayerNormalization(weight = q_norm_weight)
+ -> Reshape[*,*,num_heads * head_size]
+ -> GQA[input 0]
+
+ K_proj_out -> Reshape[*,*,head_size]
+ -> SimplifiedLayerNormalization(weight = k_norm_weight)
+ -> Reshape[*,*,kv_num_heads * head_size]
+ -> GQA[input 1]
+
+When matched, the six Reshape/SLN nodes are removed and the pre-norm Q and K
+projections feed GQA directly. The kernel is responsible for applying the RMS
+norm internally (currently the WebGPU EP).
+
+The GQA node is left untouched unless all of the following hold:
+ - it is com.microsoft GroupQueryAttention version 1 with inputs 14 and 15 unset;
+ - query, key and value (inputs 0-2) are all present, i.e. not the packed-QKV form;
+ - its do_rotary attribute is 1 and both num_heads and kv_num_heads are set;
+ - past_key (input 3) has a known shape of at least 4 dims with a static dim 3,
+   which is taken as head_size;
+ - on both the Q and K side, every Reshape/SLN node in the chain has exactly one
+   output edge and does not produce a graph output;
+ - each SLN uses axis -1, has no additional named outputs, uses one element type
+   for input, scale and output, and its scale is a constant initializer of shape
+   [head_size];
+ - the two SLN epsilon values agree (within 1e-9); that value becomes
+   qk_norm_epsilon.
+
+The rewrite replaces the GQA node with a new GroupQueryAttention node (named after
+the original with a "_qknorm" suffix) that keeps the original output NodeArgs,
+attributes and execution provider assignment.
+
+Only fires for execution providers passed in `compatible_execution_providers`.
+At present this fusion is registered for the WebGPU EP only, because the
+in-kernel norm path is currently implemented there. The CPU, CUDA, and JSEP
+GroupQueryAttention kernels reject q_norm_weight / k_norm_weight inputs.
+*/
+class GroupQueryAttentionPreNormFusion : public GraphTransformer {
+ public:
+ explicit GroupQueryAttentionPreNormFusion(
+ const InlinedHashSet& compatible_execution_providers = {}) noexcept
+ : GraphTransformer("GroupQueryAttentionPreNormFusion", compatible_execution_providers) {
+ }
+
+ Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h
index 636f185eda422..39cce7941a15b 100644
--- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h
+++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h
@@ -159,4 +159,4 @@ constexpr const char GeluExpr[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475
constexpr const char FastGeluExpr[] = "a * (0.5 + 0.5 * tanh_v(a * (0.035677408136300125 * a * a + 0.7978845608028654)))";
} // namespace webgpu
-} // namespace onnxruntime
\ No newline at end of file
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc
index 7d4ae8c2197ff..3afedea30adaf 100644
--- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc
+++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc
@@ -162,8 +162,6 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex
const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions());
const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(axis));
const int64_t norm_size = x_shape.SizeFromDimension(axis);
- const int components = GetMaxComponents(norm_size);
- const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components);
const auto scale_size = scale->Shape().Size();
const auto bias_size = (bias) ? bias->Shape().Size() : 0;
@@ -192,6 +190,28 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex
return Status::OK();
}
+ return RunLayerNormProgram(context, x, scale, bias, epsilon_, norm_count, norm_size,
+ simplified, y, mean, inv_std_dev);
+}
+
+Status RunLayerNormProgram(ComputeContext& context,
+ const Tensor* x,
+ const Tensor* scale,
+ const Tensor* bias,
+ float epsilon,
+ uint32_t norm_count,
+ int64_t norm_size,
+ bool simplified,
+ Tensor* y,
+ Tensor* mean,
+ Tensor* inv_std_dev) {
+ if (x->Shape().Size() == 0) {
+ return Status::OK();
+ }
+
+ const int components = GetMaxComponents(norm_size);
+ const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components);
+
// Check if we should use split norm dimension optimization
const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1;
@@ -215,7 +235,7 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex
{static_cast(norm_size_vectorized)},
})
.AddUniformVariables({
- {static_cast(epsilon_)},
+ {static_cast(epsilon)},
});
if (split_norm_dim) {
diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h
index 112b152d37130..a6323dc7721d4 100644
--- a/onnxruntime/core/providers/webgpu/nn/layer_norm.h
+++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h
@@ -56,5 +56,21 @@ class LayerNorm final : public WebGpuKernel {
int64_t stash_type_;
};
+// Configures and dispatches a LayerNormProgram. Centralizes the program-setup logic
+// (uniform variables, components, split_norm_dim heuristic, workgroup sizing) so callers
+// other than the LayerNorm kernel (e.g. fused MatMulNBits ops) do not need to duplicate it.
+// `bias`, `mean` and `inv_std_dev` may be nullptr. Returns OK without dispatching when `x` has zero elements.
+Status RunLayerNormProgram(ComputeContext& context,
+ const Tensor* x,
+ const Tensor* scale,
+ const Tensor* bias,
+ float epsilon,  // forwarded to the program as a uniform
+ uint32_t norm_count,  // the LayerNorm kernel passes x_shape.SizeToDimension(axis)
+ int64_t norm_size,  // the LayerNorm kernel passes x_shape.SizeFromDimension(axis); also drives the component (vector) width
+ bool simplified,
+ Tensor* y,
+ Tensor* mean,
+ Tensor* inv_std_dev);
+
} // namespace webgpu
} // namespace onnxruntime
diff --git a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc
new file mode 100644
index 0000000000000..4f3deb2306f38
--- /dev/null
+++ b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc
@@ -0,0 +1,333 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include
+
+#include "core/graph/node_attr_utils.h"
+#include "core/optimizer/group_query_attention_pre_norm_fusion.h"
+#include "core/optimizer/utils.h"
+
+#include "test/util/include/asserts.h"
+#include "test/unittest_util/framework_test_utils.h"
+#include "test/unittest_util/graph_transform_test_builder.h"
+#include "test/optimizer/graph_transform_test_fixture.h"
+
+#include "gtest/gtest.h"
+
+namespace onnxruntime {
+namespace test {
+
+#if !defined(DISABLE_CONTRIB_OPS)
+
+namespace {
+
+// Small geometry that exercises the Q/K post-norm pattern without needing real GPU work.
+constexpr int64_t kBatch = 1;
+constexpr int64_t kSeq = 1;
+constexpr int64_t kNumHeads = 2;  // value of the GQA num_heads attribute
+constexpr int64_t kKvNumHeads = 1;  // value of the GQA kv_num_heads attribute
+constexpr int64_t kHeadSize = 4;  // length of each norm weight and last dim of the per-head reshapes
+constexpr int64_t kQHidden = kNumHeads * kHeadSize;  // last dim of the Q projection
+constexpr int64_t kKvHidden = kKvNumHeads * kHeadSize;  // last dim of the K and V projections
+constexpr int64_t kMaxSeq = 8;  // sequence dim of the past/present key and value tensors
+
+// Assigns `node` to the WebGPU execution provider.
+void SetWebGpu(Node& node) {
+  node.SetExecutionProviderType(kWebGpuExecutionProvider);
+}
+
+// Knobs for BuildQwenQkPostNormPattern, which builds [Reshape -> SimplifiedLayerNormalization
+// -> Reshape] on Q and K feeding a GroupQueryAttention node (V goes straight into GQA).
+// The defaults describe the fusable pattern; individual tests flip a single attribute /
+// shape / epsilon to exercise each gate.
+struct BuildOptions {
+ float q_epsilon = 1e-6f;  // epsilon attribute of the Q-side SimplifiedLayerNormalization
+ float k_epsilon = 1e-6f;  // epsilon attribute of the K-side SimplifiedLayerNormalization
+ // If true, the inner reshape on the K side targets a different last-dim than head_size
+ // (kHeadSize * 2) so the matcher must reject it.
+ bool break_k_inner_reshape_shape = false;
+ // If true, the q_norm_weight initializer is given a non-1D shape ({1, kHeadSize}) so the
+ // matcher must reject it.
+ bool break_q_norm_weight_shape = false;
+ // GQA do_rotary attribute. The WebGPU fused prologue only supports do_rotary=1, so the
+ // optimizer must skip the rewrite when this is 0.
+ int64_t do_rotary = 1;
+ // If true, leave both the K and V inputs of the GQA node empty (slots 1 and 2), simulating
+ // the packed-QKV form. The optimizer must skip the rewrite in that case.
+ bool packed_qkv = false;
+ // If true, pre-populate the GQA node's slot 14 with a q_norm_weight initializer so the
+ // optimizer treats the node as already fused and skips it.
+ bool pre_fused = false;
+};
+
+void BuildQwenQkPostNormPattern(ModelTestBuilder& builder, const BuildOptions& opts) {  // populates `builder` with the Q/K norm chains plus one GQA node, all assigned to WebGPU
+ // Projection inputs (post linear projection, pre norm). These are graph inputs, so no producer node exists for them.
+ NodeArg* q_proj = builder.MakeInput(
+ std::vector{kBatch, kSeq, kQHidden}, MLFloat16(-1.0f), MLFloat16(1.0f));
+ NodeArg* k_proj = builder.MakeInput(
+ std::vector{kBatch, kSeq, kKvHidden}, MLFloat16(-1.0f), MLFloat16(1.0f));
+ NodeArg* v_proj = builder.MakeInput(
+ std::vector{kBatch, kSeq, kKvHidden}, MLFloat16(-1.0f), MLFloat16(1.0f));
+
+ // GQA cache + control inputs. past_key's last dim (kHeadSize) is what the optimizer reads as head_size.
+ NodeArg* past_key = builder.MakeInput(
+ std::vector{kBatch, kKvNumHeads, kMaxSeq, kHeadSize}, MLFloat16(0.0f), MLFloat16(0.0f));
+ NodeArg* past_value = builder.MakeInput(
+ std::vector{kBatch, kKvNumHeads, kMaxSeq, kHeadSize}, MLFloat16(0.0f), MLFloat16(0.0f));
+ // Note: ModelTestBuilder::MakeInput(shape, min, max) calls Uniform(min, max - 1)
+ // internally, which asserts on min == max. Use the explicit-data overload instead.
+ NodeArg* seqlens_k = builder.MakeInput(std::vector{kBatch}, std::vector{0});
+ NodeArg* total_seq_len = builder.MakeInput(std::vector{1}, std::vector{1});
+
+ // Norm weight initializers: [head_size]. (Or {1, head_size} on the Q side when forcing a shape mismatch.)
+ std::vector q_norm_weight_shape =
+ opts.break_q_norm_weight_shape ? std::vector{1, kHeadSize} : std::vector{kHeadSize};
+ NodeArg* q_norm_weight = builder.MakeInitializer(q_norm_weight_shape, MLFloat16(1.0f), MLFloat16(1.0f));
+ NodeArg* k_norm_weight = builder.MakeInitializer({kHeadSize}, MLFloat16(1.0f), MLFloat16(1.0f));
+
+ // Reshape "shape" initializers: 4-D per-head targets for the inner reshapes, 3-D hidden targets for the outer ones.
+ NodeArg* reshape_to_per_head_q = builder.MakeInitializer({4}, {kBatch, kSeq, kNumHeads, kHeadSize});
+ const int64_t k_inner_last_dim = opts.break_k_inner_reshape_shape ? (kHeadSize * 2) : kHeadSize;
+ NodeArg* reshape_to_per_head_k =
+ builder.MakeInitializer({4}, {kBatch, kSeq, kKvNumHeads, k_inner_last_dim});
+ NodeArg* reshape_to_q_hidden = builder.MakeInitializer({3}, {kBatch, kSeq, kQHidden});
+ NodeArg* reshape_to_kv_hidden = builder.MakeInitializer({3}, {kBatch, kSeq, kKvHidden});
+
+ // Q-side chain: q_proj -> Reshape -> SLN -> Reshape.
+ NodeArg* q_inner_reshape_out = builder.MakeIntermediate(
+ std::vector{kBatch, kSeq, kNumHeads, kHeadSize});
+ NodeArg* q_normed = builder.MakeIntermediate(
+ std::vector{kBatch, kSeq, kNumHeads, kHeadSize});
+ NodeArg* q_outer_reshape_out = builder.MakeIntermediate(
+ std::vector{kBatch, kSeq, kQHidden});
+
+ Node& q_inner_reshape = builder.AddNode("Reshape", {q_proj, reshape_to_per_head_q}, {q_inner_reshape_out});
+ Node& q_sln = builder.AddNode("SimplifiedLayerNormalization", {q_inner_reshape_out, q_norm_weight}, {q_normed});
+ q_sln.AddAttribute("axis", static_cast(-1));
+ q_sln.AddAttribute("epsilon", opts.q_epsilon);
+ Node& q_outer_reshape = builder.AddNode("Reshape", {q_normed, reshape_to_q_hidden}, {q_outer_reshape_out});
+
+ // K-side chain: k_proj -> Reshape -> SLN -> Reshape. The declared intermediate shapes follow k_inner_last_dim.
+ NodeArg* k_inner_reshape_out = builder.MakeIntermediate(
+ std::vector{kBatch, kSeq, kKvNumHeads, k_inner_last_dim});
+ NodeArg* k_normed = builder.MakeIntermediate(
+ std::vector{kBatch, kSeq, kKvNumHeads, k_inner_last_dim});
+ NodeArg* k_outer_reshape_out = builder.MakeIntermediate(
+ std::vector{kBatch, kSeq, kKvHidden});
+
+ Node& k_inner_reshape = builder.AddNode("Reshape", {k_proj, reshape_to_per_head_k}, {k_inner_reshape_out});
+ Node& k_sln = builder.AddNode("SimplifiedLayerNormalization", {k_inner_reshape_out, k_norm_weight}, {k_normed});
+ k_sln.AddAttribute("axis", static_cast(-1));
+ k_sln.AddAttribute("epsilon", opts.k_epsilon);
+ Node& k_outer_reshape = builder.AddNode("Reshape", {k_normed, reshape_to_kv_hidden}, {k_outer_reshape_out});
+
+ // GQA outputs (registered as graph outputs).
+ NodeArg* gqa_out = builder.MakeOutput(std::vector{kBatch, kSeq, kQHidden});
+ NodeArg* present_key = builder.MakeOutput(
+ std::vector{kBatch, kKvNumHeads, kMaxSeq, kHeadSize});
+ NodeArg* present_value = builder.MakeOutput(
+ std::vector{kBatch, kKvNumHeads, kMaxSeq, kHeadSize});
+
+ // Build the GQA input list. The packed_qkv variant drops K (slot 1) and V (slot 2).
+ NodeArg& empty_arg = builder.graph_.GetOrCreateNodeArg("", nullptr);  // empty-named NodeArg stands in for an absent optional input
+ std::vector gqa_inputs;
+ gqa_inputs.push_back(q_outer_reshape_out);
+ gqa_inputs.push_back(opts.packed_qkv ? &empty_arg : k_outer_reshape_out);
+ gqa_inputs.push_back(opts.packed_qkv ? &empty_arg : v_proj);
+ gqa_inputs.push_back(past_key);
+ gqa_inputs.push_back(past_value);
+ gqa_inputs.push_back(seqlens_k);
+ gqa_inputs.push_back(total_seq_len);
+
+ if (opts.pre_fused) {
+ // Pad slots 7..13 with empty args, then place a real norm weight in slot 14. Slot 15 is left off the list.
+ for (int i = 7; i < 14; ++i) {
+ gqa_inputs.push_back(&empty_arg);
+ }
+ NodeArg* preexisting_q_norm =
+ builder.MakeInitializer({kHeadSize}, MLFloat16(1.0f), MLFloat16(1.0f));
+ gqa_inputs.push_back(preexisting_q_norm);
+ }
+
+ Node& gqa = builder.AddNode("GroupQueryAttention",
+ gqa_inputs,
+ {gqa_out, present_key, present_value},
+ kMSDomain);
+ gqa.AddAttribute("num_heads", static_cast(kNumHeads));
+ gqa.AddAttribute("kv_num_heads", static_cast(kKvNumHeads));
+ gqa.AddAttribute("do_rotary", opts.do_rotary);
+
+ SetWebGpu(q_inner_reshape);  // every node of the pattern is assigned to the WebGPU EP
+ SetWebGpu(q_sln);
+ SetWebGpu(q_outer_reshape);
+ SetWebGpu(k_inner_reshape);
+ SetWebGpu(k_sln);
+ SetWebGpu(k_outer_reshape);
+ SetWebGpu(gqa);
+}
+
+// Post-transform check for the positive case: exactly one GroupQueryAttention remains, no
+// SimplifiedLayerNormalization or Reshape nodes survive, and the GQA node exposes both norm
+// weights (slots 14 and 15) together with a qk_norm_epsilon attribute equal to 1e-6.
+Status CheckFusedGraph(Graph& graph) {
+  const auto op_to_count = CountOpsInGraph(graph);
+  const bool counts_as_expected = OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1 &&
+                                  OpCount(op_to_count, "SimplifiedLayerNormalization") == 0 &&
+                                  OpCount(op_to_count, "Reshape") == 0;
+  if (!counts_as_expected) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                           "Unexpected op counts after GroupQueryAttentionPreNormFusion: ",
+                           "GQA=", OpCount(op_to_count, "com.microsoft.GroupQueryAttention"),
+                           " SLN=", OpCount(op_to_count, "SimplifiedLayerNormalization"),
+                           " Reshape=", OpCount(op_to_count, "Reshape"));
+  }
+
+  for (const auto& node : graph.Nodes()) {
+    // Only the GQA node carries the fused inputs and attribute.
+    if (node.OpType() == "GroupQueryAttention") {
+      ORT_RETURN_IF_NOT(node.InputDefs().size() >= 16, "Fused GQA must expose 16 input slots.");
+
+      const auto* q_norm = node.InputDefs()[14];
+      const auto* k_norm = node.InputDefs()[15];
+      ORT_RETURN_IF_NOT(q_norm != nullptr && q_norm->Exists(), "q_norm_weight (slot 14) missing.");
+      ORT_RETURN_IF_NOT(k_norm != nullptr && k_norm->Exists(), "k_norm_weight (slot 15) missing.");
+
+      const auto& attrs = node.GetAttributes();
+      auto eps_it = attrs.find("qk_norm_epsilon");
+      ORT_RETURN_IF_NOT(eps_it != attrs.end(), "qk_norm_epsilon attribute missing.");
+      ORT_RETURN_IF_NOT(std::abs(eps_it->second.f() - 1e-6f) < 1e-9f, "qk_norm_epsilon value mismatch.");
+    }
+  }
+  return Status::OK();
+}
+
+Status CheckUnfusedGraph(Graph& graph) {
+  // An untouched pattern still has one GQA node, two SLN nodes and four Reshape nodes.
+  const auto counts = CountOpsInGraph(graph);
+  const bool pattern_intact = OpCount(counts, "com.microsoft.GroupQueryAttention") == 1 &&
+                              OpCount(counts, "SimplifiedLayerNormalization") == 2 &&
+                              OpCount(counts, "Reshape") == 4;
+  if (!pattern_intact) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Negative test: graph was fused unexpectedly.");
+  }
+  for (const auto& node : graph.Nodes()) {
+    // Only a GQA node with at least 15 inputs has a slot 14 (q_norm_weight) to inspect.
+    if (node.OpType() != "GroupQueryAttention" || node.InputDefs().size() < 15) continue;
+    const auto* q_norm = node.InputDefs()[14];
+    ORT_RETURN_IF_NOT(q_norm == nullptr || !q_norm->Exists(), "Negative test: q_norm_weight should not be wired.");
+  }
+  return Status::OK();
+}
+
+} // namespace
+
+// Helper: builds the transformer under test with kWebGpuExecutionProvider as its only EP (matches production).
+std::unique_ptr MakeWebGpuTransformer() {  // NOTE(review): defined after the anonymous namespace closes - confirm intended
+ return std::make_unique(
+ InlinedHashSet{kWebGpuExecutionProvider});  // EP set handed to the transformer's constructor
+}
+
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesQwenPattern) {
+  // Positive case: the default-option pattern must be folded into the GQA node (see CheckFusedGraph).
+  auto make_graph = [](ModelTestBuilder& b) { BuildQwenQkPostNormPattern(b, BuildOptions{}); };
+  ASSERT_STATUS_OK(TestGraphTransformer(make_graph, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+                                        TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph));
+}
+
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsEpsilonMismatch) {
+  // Negative case: q_epsilon and k_epsilon differ (1e-6 vs 1e-5), so the graph must come back unfused.
+  BuildOptions mismatched_eps;
+  mismatched_eps.q_epsilon = 1e-6f;
+  mismatched_eps.k_epsilon = 1e-5f;
+  auto make_graph = [mismatched_eps](ModelTestBuilder& b) { BuildQwenQkPostNormPattern(b, mismatched_eps); };
+  ASSERT_STATUS_OK(TestGraphTransformer(make_graph, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+                                        TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
+}
+
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsBadInnerReshape) {
+  // Negative case: the builder is asked to break the K-side inner Reshape shape; no fusion may happen.
+  BuildOptions bad_reshape;
+  bad_reshape.break_k_inner_reshape_shape = true;
+  auto make_graph = [bad_reshape](ModelTestBuilder& b) { BuildQwenQkPostNormPattern(b, bad_reshape); };
+  ASSERT_STATUS_OK(TestGraphTransformer(make_graph, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+                                        TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
+}
+
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNon1DNormWeight) {
+  // Negative case: the builder is asked to break the Q norm weight shape (non-1D per the test name).
+  BuildOptions bad_weight;
+  bad_weight.break_q_norm_weight_shape = true;
+  auto make_graph = [bad_weight](ModelTestBuilder& b) { BuildQwenQkPostNormPattern(b, bad_weight); };
+  ASSERT_STATUS_OK(TestGraphTransformer(make_graph, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+                                        TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
+}
+
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) {
+ // Build the default pattern, then reassign every node to the CPU EP. The transformer under test is
+ // created for WebGPU only (see MakeWebGpuTransformer), so the graph must remain unfused.
+ auto build = [](ModelTestBuilder& builder) {
+ BuildQwenQkPostNormPattern(builder, BuildOptions{});
+ for (auto& node : builder.graph_.Nodes()) {  // runs after the builder; presumably overrides its SetWebGpu assignment
+ const_cast(node).SetExecutionProviderType(kCpuExecutionProvider);
+ }
+ };
+ ASSERT_STATUS_OK(TestGraphTransformer(
+ build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+ TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));  // post-check: op counts intact, slot 14 unwired
+}
+
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsJsEp) {
+ // JSEP does not implement the fused per-head Q/K RMSNorm prologue. The transformer under test is
+ // created for WebGPU only (see MakeWebGpuTransformer), so JSEP-assigned graphs must be left alone.
+ auto build = [](ModelTestBuilder& builder) {
+ BuildQwenQkPostNormPattern(builder, BuildOptions{});
+ for (auto& node : builder.graph_.Nodes()) {  // runs after the builder; presumably overrides its SetWebGpu assignment
+ const_cast(node).SetExecutionProviderType(kJsExecutionProvider);
+ }
+ };
+ ASSERT_STATUS_OK(TestGraphTransformer(
+ build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+ TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));  // post-check: op counts intact, slot 14 unwired
+}
+
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsWhenDoRotaryDisabled) {
+  // Negative case: the builder copies opts.do_rotary into the GQA do_rotary attribute, here 0.
+  // The WebGPU fused prologue requires do_rotary=1, so the optimizer must skip the rewrite
+  // and the runtime guard never trips.
+  BuildOptions no_rotary;
+  no_rotary.do_rotary = 0;
+  auto make_graph = [no_rotary](ModelTestBuilder& b) { BuildQwenQkPostNormPattern(b, no_rotary); };
+  ASSERT_STATUS_OK(TestGraphTransformer(make_graph, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+                                        TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
+}
+
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsPackedQkv) {
+  // Negative case: in packed-QKV form slots 1 and 2 are left empty.
+  // The WebGPU fused prologue does not support that layout, so the optimizer must skip
+  // the rewrite and leave the pattern unfused.
+  BuildOptions packed;
+  packed.packed_qkv = true;
+  auto make_graph = [packed](ModelTestBuilder& b) { BuildQwenQkPostNormPattern(b, packed); };
+  ASSERT_STATUS_OK(TestGraphTransformer(make_graph, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+                                        TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph));
+}
+
+TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsAlreadyFusedNode) {
+  // pre_fused makes the builder wire an initializer into GQA slot 14 (q_norm_weight) up front.
+  // The optimizer must treat such a node as already fused and keep the SLN/Reshape ops around it.
+  // CheckUnfusedGraph cannot be reused here because it rejects any wiring at slot 14, so a
+  // local checker verifies only that the surrounding ops were not removed.
+  BuildOptions already_fused;
+  already_fused.pre_fused = true;
+  auto make_graph = [already_fused](ModelTestBuilder& b) { BuildQwenQkPostNormPattern(b, already_fused); };
+  auto ops_untouched = [](Graph& graph) -> Status {
+    // One GQA, two SLN and four Reshape nodes: the same counts CheckUnfusedGraph expects.
+    const auto op_to_count = CountOpsInGraph(graph);
+    ORT_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1,
+                      "Already-fused test: GQA count changed.");
+    ORT_RETURN_IF_NOT(OpCount(op_to_count, "SimplifiedLayerNormalization") == 2,
+                      "Already-fused test: SLN ops were removed.");
+    ORT_RETURN_IF_NOT(OpCount(op_to_count, "Reshape") == 4,
+                      "Already-fused test: Reshape ops were removed.");
+    return Status::OK();
+  };
+  ASSERT_STATUS_OK(TestGraphTransformer(make_graph, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(),
+                                        TransformerLevel::Level2, /*steps=*/1, nullptr, ops_untouched));
+}
+
+#endif // !defined(DISABLE_CONTRIB_OPS)
+
+} // namespace test
+} // namespace onnxruntime