diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 6df316097e719..d830585a708a8 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2625,6 +2625,8 @@ This version of the operator has been available since version 1 of the 'com.micr
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
+
qk_norm_epsilon : float
+
Epsilon used by the per-head RMS norm applied to Q and K when q_norm_weight and k_norm_weight inputs are provided. Default value is 1e-6.
qk_output : int
Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).
rotary_interleaved : int
@@ -2639,7 +2641,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Quantization type for V cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.
-#### Inputs (7 - 14) +#### Inputs (7 - 16)
query : T
@@ -2670,6 +2672,10 @@ This version of the operator has been available since version 1 of the 'com.micr
Scale tensor for past_key.
v_scale (optional) : T_KV_SCALE
Scale tensor for past_value.
+
q_norm_weight (optional) : T
+
Optional 1D tensor of shape (head_size). When provided together with k_norm_weight, the kernel applies a per-head RMS normalization to Q (and K) before any rotary embedding. Used by Qwen3-style models that wrap their Q/K projections in a Reshape -> SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion folds that pattern into this input. Currently honored by the native WebGPU execution provider only; JSEP WebGPU/JS and other EPs must reject the node when this input is set.
+
k_norm_weight (optional) : T
+
Optional 1D tensor of shape (head_size). See q_norm_weight. Must be provided together with q_norm_weight.
#### Outputs (3 - 4) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d0d8e750285d4..4a2399c000f46 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1088,7 +1088,7 @@ The **OpSet Version** column uses the following notation: |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)
**T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)
**T_KV_SCALE** = tensor(float)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*in* q_norm_weight:**T**
*in* k_norm_weight:**T**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)
**T_CACHE** = tensor(bfloat16), tensor(float16), tensor(float8e4m3fn), tensor(int8)
**T_KV_SCALE** = tensor(float)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LinearAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_state:**S**
*in* decay:**T**
*in* beta:**T**
*out* output:**T**
*out* present_state:**S**|1+|**T** = tensor(float), tensor(float16)| @@ -1575,7 +1575,7 @@ The **OpSet Version** column uses the following notation: |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*in* q_norm_weight:**T**
*in* k_norm_weight:**T**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| diff --git a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts index 9050c1bbb8816..ada0c65bdd8a7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts @@ -328,6 +328,16 @@ const generatePositionIdsProgramInfo = ( }; export const groupQueryAttention = (context: ComputeContext, attributes: GroupQueryAttentionAttributes): void => { + // q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the WebGPU-only + // GroupQueryAttentionPreNormFusion optimizer pass. JSEP does not implement the fused + // per-head Q/K RMS normalization prologue, so reject the node if either input is present + // (regardless of rank, including scalars) rather than silently dropping the normalization. + if ((context.inputs.length > 14 && context.inputs[14]) || (context.inputs.length > 15 && context.inputs[15])) { + throw new Error( + 'GroupQueryAttention (JSEP): q_norm_weight / k_norm_weight inputs are not supported. ' + + 'The per-head Q/K RMS normalization prologue is implemented only on the native WebGPU EP.', + ); + } const params = validateInputs(context.inputs, attributes); if (context.inputs[0].dims.length === 5) { throw new Error('Packed QKV is not implemented'); diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 4df5f6a349599..8eb7c73f8a4a9 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -84,6 +84,18 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { "kv_cache_bit_width must be 0 when quantization is disabled, got ", kv_cache_bit_width_); } + // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only + // GroupQueryAttentionPreNormFusion optimizer pass. The CPU kernel does not implement + // the fused per-head Q/K RMS normalization prologue, so reject the node if either input + // is present rather than silently dropping the normalization. + if ((context->InputCount() > 14 && context->Input(14) != nullptr) || + (context->InputCount() > 15 && context->Input(15) != nullptr)) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "GroupQueryAttention (CPU): q_norm_weight / k_norm_weight inputs are not supported. " + "The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP."); + } + GroupQueryAttentionParameters parameters = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, key, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 44408e6ce4af9..ea84fb973091c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -167,6 +167,18 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons const Tensor* k_scale = context->Input(12); const Tensor* v_scale = context->Input(13); + // q_norm_weight (input 14) / k_norm_weight (input 15) are populated by the WebGPU-only + // GroupQueryAttentionPreNormFusion optimizer pass. The CUDA kernel does not implement + // the fused per-head Q/K RMS normalization prologue, so reject the node if either input + // is present rather than silently dropping the normalization. + if ((context->InputCount() > 14 && context->Input(14) != nullptr) || + (context->InputCount() > 15 && context->Input(15) != nullptr)) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "GroupQueryAttention (CUDA): q_norm_weight / k_norm_weight inputs are not supported. " + "The per-head Q/K RMS normalization prologue is implemented only on the WebGPU EP."); + } + if (k_quant_type_ != KVQuantizationType::NONE) { if (k_scale == nullptr) { return ORT_MAKE_STATUS( diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index e3b91bdbb82f4..7fd1271e37132 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -9,6 +9,7 @@ #include "contrib_ops/webgpu/bert/flash_attention.h" #include "core/common/narrow.h" +#include "core/providers/webgpu/nn/layer_norm.h" #include "core/providers/webgpu/webgpu_supported_types.h" #include "core/providers/webgpu/shader_helper.h" @@ -104,7 +105,11 @@ Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& return context.RunProgram(program); } -// Fused Q/K rotary embedding +// Fused Q/K rotary embedding. When q_norm_weight and k_norm_weight are non-null, a per-head +// RMS normalization (Q[c] *= inverseSqrt(mean(Q[..]^2)+eps) * q_norm_weight[c]; same for K) +// is fused into the rotary kernel ahead of the rotation. This decode-only fast path replaces +// the standalone SimplifiedLayerNormalization dispatches that GroupQueryAttentionPreNormFusion +// folds away. Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* query_in, @@ -113,7 +118,10 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const Tensor* cos_cache, const Tensor* sin_cache, Tensor* query_out, - Tensor* key_out) { + Tensor* key_out, + const Tensor* q_norm_weight = nullptr, + const Tensor* k_norm_weight = nullptr, + float qk_norm_epsilon = 0.0f) { const auto half_rotary_embedding_dim = gsl::narrow_cast(cos_cache->Shape()[1]); const auto head_size = params.head_size_; @@ -155,9 +163,10 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, 1u}); // Dispatch computations only over the Q domain, and fuse K write operations using a head-index-based condition. - FusedQKRotaryEmbeddingProgram program(params.rotary_interleaved_); + const bool has_qk_norm = (q_norm_weight != nullptr) && (k_norm_weight != nullptr); + FusedQKRotaryEmbeddingProgram program(params.rotary_interleaved_, has_qk_norm); program - .CacheHint(params.rotary_interleaved_) + .CacheHint(params.rotary_interleaved_, has_qk_norm) .AddInputs({ {query_in, ProgramTensorMetadataDependency::TypeAndRank}, {key_in, ProgramTensorMetadataDependency::Rank}, @@ -178,8 +187,17 @@ Status RunFusedQKRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, {gsl::make_span(k_global_dims)}, {gsl::make_span(k_input_output_strides)}, {q_domain_size}, + {static_cast(head_size)}, + {qk_norm_epsilon}, }); + if (has_qk_norm) { + program.AddInputs({ + {q_norm_weight, ProgramTensorMetadataDependency::Type}, + {k_norm_weight, ProgramTensorMetadataDependency::Type}, + }); + } + return context.RunProgram(program); } @@ -196,6 +214,20 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& const Tensor* position_ids = context.Input(9); // TODO: support sliding window const Tensor* attention_bias = context.Input(10); const Tensor* head_sink = context.Input(11); + // Inputs 12 and 13 are k_scale / v_scale (KV-cache quant). Not consumed by WebGPU yet. + // Inputs 14 and 15 are q_norm_weight / k_norm_weight, populated by + // GroupQueryAttentionPreNormFusion. WebGPU supports these inputs for the configurations + // validated below (do_rotary, non-packed Q/K/V). + const Tensor* q_norm_weight = context.InputCount() > 14 ? context.Input(14) : nullptr; + const Tensor* k_norm_weight = context.InputCount() > 15 ? context.Input(15) : nullptr; + const bool has_qk_norm = (q_norm_weight != nullptr) && (k_norm_weight != nullptr); + // The current fused prologue only supports the Qwen3-style configuration that + // GroupQueryAttentionPreNormFusion targets: do_rotary, non-packed Q/K/V. Reject any + // other configuration so downstream rewrites cannot land silently. + ORT_RETURN_IF(((q_norm_weight != nullptr) ^ (k_norm_weight != nullptr)), + "GroupQueryAttention: q_norm_weight and k_norm_weight must be provided together."); + ORT_RETURN_IF(has_qk_norm && !do_rotary_, + "GroupQueryAttention: q_norm_weight / k_norm_weight require do_rotary=1."); GroupQueryAttentionParameters params = {}; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, @@ -227,6 +259,23 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); WebgpuAttentionParameters parameters(params); + ORT_RETURN_IF(has_qk_norm && parameters.is_packed_qkv_, + "GroupQueryAttention: q_norm_weight / k_norm_weight are not supported when QKV is packed."); + if (has_qk_norm) { + // The fused prologue indexes q/k_norm_weight as a 1-D tensor of length head_size. Validate + // shape here so a hand-authored model with a wrong shape fails with INVALID_ARGUMENT instead + // of silently reading the wrong offsets (or out of bounds). + const auto& q_norm_shape = q_norm_weight->Shape(); + ORT_RETURN_IF_NOT(q_norm_shape.NumDimensions() == 1 && + q_norm_shape[0] == static_cast(parameters.head_size_), + "GroupQueryAttention: q_norm_weight must be a 1-D tensor of shape [head_size=", + parameters.head_size_, "], got ", q_norm_shape.ToString(), "."); + const auto& k_norm_shape = k_norm_weight->Shape(); + ORT_RETURN_IF_NOT(k_norm_shape.NumDimensions() == 1 && + k_norm_shape[0] == static_cast(parameters.head_size_), + "GroupQueryAttention: k_norm_weight must be a 1-D tensor of shape [head_size=", + parameters.head_size_, "], got ", k_norm_shape.ToString(), "."); + } TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size_); output_shape[1] = static_cast(parameters.sequence_length_); @@ -304,16 +353,63 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& value = &vSplit; } if (do_rotary_) { + // Per-head RMS normalization handling for Qwen3-style models (GQA inputs 14/15). + // - Decode (sequence_length == 1): fold the norm into the FusedQKRotaryEmbedding + // kernel. Each thread re-reads its head's head_size channels (Approach A); no + // reductions, no shared memory. Sub-microsecond overhead vs ~60us/layer SLN savings. + // - Prefill (sequence_length > 1): fall back to two standalone SimplifiedLayerNorm + // dispatches into scratch tensors, then run the unfused FusedQKRotaryEmbedding. + // Matches the pre-fusion graph timing exactly so prefill cannot regress. + Tensor qNorm; + Tensor kNorm; + const Tensor* q_for_rotary = query; + const Tensor* k_for_rotary = key; + const Tensor* q_norm_for_fused = nullptr; + const Tensor* k_norm_for_fused = nullptr; + const bool decode_norm_fast_path = has_qk_norm && parameters.sequence_length_ == 1; + if (has_qk_norm && !decode_norm_fast_path) { + qNorm = context.CreateGPUTensor(query->DataType(), query->Shape()); + kNorm = context.CreateGPUTensor(key->DataType(), key->Shape()); + const uint32_t q_norm_count = + static_cast(parameters.batch_size_) * + static_cast(parameters.sequence_length_) * + static_cast(parameters.num_heads_); + const uint32_t k_norm_count = + static_cast(parameters.batch_size_) * + static_cast(parameters.sequence_length_) * + static_cast(parameters.kv_num_heads_); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram( + context, query, q_norm_weight, /*bias=*/nullptr, qk_norm_epsilon_, + q_norm_count, static_cast(parameters.head_size_), + /*simplified=*/true, &qNorm, /*mean=*/nullptr, /*inv_std_dev=*/nullptr)); + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::RunLayerNormProgram( + context, key, k_norm_weight, /*bias=*/nullptr, qk_norm_epsilon_, + k_norm_count, static_cast(parameters.head_size_), + /*simplified=*/true, &kNorm, /*mean=*/nullptr, /*inv_std_dev=*/nullptr)); + q_for_rotary = &qNorm; + k_for_rotary = &kNorm; + } else if (decode_norm_fast_path) { + q_norm_for_fused = q_norm_weight; + k_norm_for_fused = k_norm_weight; + } // rotary QK - qRotary = context.CreateGPUTensor(query->DataType(), query->Shape()); - kRotary = context.CreateGPUTensor(key->DataType(), key->Shape()); + qRotary = context.CreateGPUTensor(q_for_rotary->DataType(), q_for_rotary->Shape()); + kRotary = context.CreateGPUTensor(k_for_rotary->DataType(), k_for_rotary->Shape()); ORT_RETURN_IF_ERROR(RunFusedQKRotaryEmbedding(context, parameters, - query, key, + q_for_rotary, k_for_rotary, seqlen_k, cos_cache, sin_cache, - &qRotary, &kRotary)); + &qRotary, &kRotary, + q_norm_for_fused, k_norm_for_fused, + qk_norm_epsilon_)); query = &qRotary; key = &kRotary; + } else if (has_qk_norm) { + // Defensive: do_rotary_ guard above should make this unreachable, but keep it + // explicit so a future schema/config drift surfaces as a clear error. + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupQueryAttention: q/k norm weights require do_rotary=1 (no rotary, no norm path)."); } } diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index 4127a8928f38e..cbb5b806eb6ad 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -58,6 +58,8 @@ class GroupQueryAttention final : public WebGpuKernel { use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + + qk_norm_epsilon_ = info.GetAttrOrDefault("qk_norm_epsilon", 1e-6f); } int num_heads_; // number of attention heads of Q @@ -69,6 +71,10 @@ class GroupQueryAttention final : public WebGpuKernel { int local_window_size_; bool use_smooth_softmax_; + // Epsilon used by per-head RMSNorm when q_norm_weight / k_norm_weight (inputs 14 / 15) are + // provided. Consumed whenever those optional norm inputs are used (decode fast path or + // prefill fallback), and ignored otherwise. + float qk_norm_epsilon_; Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; }; diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index 69d2db391ce3c..58f7b54bd8840 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -66,22 +66,95 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { } Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { - // Inputs - const auto& q_input = shader.AddInput("q_input", ShaderUsage::UseUniform); - const auto& k_input = shader.AddInput("k_input", ShaderUsage::UseUniform); + // Inputs. q_input/k_input use the element-type alias when has_qk_norm_ is true so we can + // mix in the f32-computed inverse-RMS scale at element-type precision. + const ShaderUsage qk_input_usage = has_qk_norm_ + ? (ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias) + : ShaderUsage::UseUniform; + const auto& q_input = shader.AddInput("q_input", qk_input_usage); + const auto& k_input = shader.AddInput("k_input", qk_input_usage); const auto& seqlens = shader.AddInput("seqlens", ShaderUsage::UseUniform); const auto& cos_cache = shader.AddInput("cos_cache", ShaderUsage::UseUniform); const auto& sin_cache = shader.AddInput("sin_cache", ShaderUsage::UseUniform); + + // Optional per-head RMS norm weights (1D tensors of length head_size). When present, + // a fused per-head normalization is applied to Q/K before the rotary rotation: + // x_norm[c] = x[c] * inverseSqrt(mean(x[..]^2) + epsilon) * weight[c] + // Decode-only fast path: each thread re-reads its own head's head_size channels to + // compute the sum-of-squares (no reductions, no shared memory). The redundant L1 + // traffic is sub-microsecond on Qwen3-1.7B decode geometry. + if (has_qk_norm_) { + shader.AddInput("q_norm_weight", ShaderUsage::UseUniform); + shader.AddInput("k_norm_weight", ShaderUsage::UseUniform); + } + // Outputs const auto& q_output = shader.AddOutput("q_output", ShaderUsage::UseUniform); const auto& k_output = shader.AddOutput("k_output", ShaderUsage::UseUniform); const auto interleaved_str = interleaved_ ? "true" : "false"; - shader.MainFunctionBody() + auto& body = shader.MainFunctionBody(); + body << " if (global_idx >= uniforms.q_domain_size) { return; }\n" << " let half_rotary_dim = uniforms.cos_cache_shape[1];\n" << " let bsnh = global_idx / uniforms.q_global_stride % uniforms.q_global_shape;\n" + << " let needs_k = bsnh[2] < uniforms.k_global_shape[2];\n"; + + // Per-head RMS computation (Approach A, no reductions). For non-interleaved layouts the + // bsnh[3] coordinate is the lower channel of a rotary pair, so the head base offset is + // dot(bsnh, stride) - bsnh[3] (i.e. drop the channel contribution). q_input_output_stride[3] + // is always 1 (channel stride), so subtracting bsnh[3] gives the head's channel-0 offset + // for both interleaved and non-interleaved layouts in the rotated branch. In the + // passthrough else-branch we recompute from bsnh[0..2] explicitly. + if (has_qk_norm_) { + body + << " let q_head_base = bsnh[0] * uniforms.q_input_output_stride[0]\n" + << " + bsnh[1] * uniforms.q_input_output_stride[1]\n" + << " + bsnh[2] * uniforms.q_input_output_stride[2];\n" + << " var q_sumsq: f32 = 0.0;\n" + << " for (var c: u32 = 0u; c < uniforms.head_size; c = c + 1u) {\n" + << " let q_v = f32(" << q_input.GetByOffset("q_head_base + c") << ");\n" + << " q_sumsq = q_sumsq + q_v * q_v;\n" + << " }\n" + << " let q_inv_rms = q_input_element_t(inverseSqrt(q_sumsq / f32(uniforms.head_size) + uniforms.qk_norm_epsilon));\n" + << " let k_head_base = bsnh[0] * uniforms.k_input_output_stride[0]\n" + << " + bsnh[1] * uniforms.k_input_output_stride[1]\n" + << " + bsnh[2] * uniforms.k_input_output_stride[2];\n" + << " var k_inv_rms = k_input_element_t(0);\n" + << " if (needs_k) {\n" + << " var k_sumsq: f32 = 0.0;\n" + << " for (var c: u32 = 0u; c < uniforms.head_size; c = c + 1u) {\n" + << " let k_v = f32(" << k_input.GetByOffset("k_head_base + c") << ");\n" + << " k_sumsq = k_sumsq + k_v * k_v;\n" + << " }\n" + << " k_inv_rms = k_input_element_t(inverseSqrt(k_sumsq / f32(uniforms.head_size) + uniforms.qk_norm_epsilon));\n" + << " }\n"; + } + + // Helpers that load Q/K and (when has_qk_norm_) apply the fused per-channel norm scale. + // The channel index expressions match the qi/qj/ki/kj/qk/kk computations used below. + auto load_q = [&](const std::string& off, const std::string& chan) { + if (!has_qk_norm_) { + return q_input.GetByOffset(off); + } + return std::string("(") + q_input.GetByOffset(off) + " * q_inv_rms * q_norm_weight[" + chan + "])"; + }; + auto load_k = [&](const std::string& off, const std::string& chan) { + if (!has_qk_norm_) { + return k_input.GetByOffset(off); + } + return std::string("(") + k_input.GetByOffset(off) + " * k_inv_rms * k_norm_weight[" + chan + "])"; + }; + + // Channel index expressions for the rotated branch. For interleaved layout the pair is + // (2*bsnh[3], 2*bsnh[3]+1); otherwise it is (bsnh[3], bsnh[3]+half_rotary_dim). + const std::string c_i = interleaved_ ? "(2u * bsnh[3])" : "bsnh[3]"; + const std::string c_j = interleaved_ ? "(2u * bsnh[3] + 1u)" : "(bsnh[3] + half_rotary_dim)"; + // Channel index for the passthrough else-branch (only fires when head_size > 2 * half_rotary_dim). + const std::string c_k = "(bsnh[3] + half_rotary_dim)"; + + body << " if (bsnh[3] < half_rotary_dim) {\n" << " let batch_idx = bsnh[0];\n" << " let sequence_idx = bsnh[1];\n" @@ -89,46 +162,51 @@ Status FusedQKRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) c << " let seqlen = u32(seqlen_i);\n" << " let total_seqlen = seqlen + 1u;\n" << " let past_seqlen = total_seqlen - uniforms.q_global_shape[1];\n" - // position_id is derived from past_seqlen + sequence_idx (always non-negative). << " let position_id = past_seqlen + sequence_idx;\n" << " let qi = dot(bsnh, uniforms.q_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" << " let qj = qi + select(half_rotary_dim, 1u, " << interleaved_str << ");\n" - // Bounds check: position_id must be within cos/sin cache range. - // On OOB, pass through input unchanged (same as CUDA kernel behavior). - " let max_position = uniforms.cos_cache_shape[0];\n" - " if (position_id >= max_position) {\n" - << " " << q_output.SetByOffset("qi", q_input.GetByOffset("qi")) << "\n" - << " " << q_output.SetByOffset("qj", q_input.GetByOffset("qj")) << "\n" - << " if (bsnh[2] < uniforms.k_global_shape[2]) {\n" + << " let q_at_qi = " << load_q("qi", c_i) << ";\n" + << " let q_at_qj = " << load_q("qj", c_j) << ";\n" + << " let max_position = uniforms.cos_cache_shape[0];\n" + << " if (position_id >= max_position) {\n" + // Bounds check: position_id must be within cos/sin cache range. + // On OOB, pass through input (norm-applied if has_qk_norm_) unchanged. + << " " << q_output.SetByOffset("qi", "q_at_qi") << "\n" + << " " << q_output.SetByOffset("qj", "q_at_qj") << "\n" + << " if (needs_k) {\n" << " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" << " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n" - << " " << k_output.SetByOffset("ki", k_input.GetByOffset("ki")) << "\n" - << " " << k_output.SetByOffset("kj", k_input.GetByOffset("kj")) << "\n" - " }\n" - " } else {\n" + << " let k_at_ki = " << load_k("ki", c_i) << ";\n" + << " let k_at_kj = " << load_k("kj", c_j) << ";\n" + << " " << k_output.SetByOffset("ki", "k_at_ki") << "\n" + << " " << k_output.SetByOffset("kj", "k_at_kj") << "\n" + << " }\n" + << " } else {\n" << " let cos_v = " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" << " let sin_v = " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" - << " let q_re = " << q_input.GetByOffset("qi") << " * cos_v - " << q_input.GetByOffset("qj") << " * sin_v;\n" + << " let q_re = q_at_qi * cos_v - q_at_qj * sin_v;\n" << " " << q_output.SetByOffset("qi", "q_re") << "\n" - << " let q_im = " << q_input.GetByOffset("qi") << " * sin_v + " << q_input.GetByOffset("qj") << " * cos_v;\n" + << " let q_im = q_at_qi * sin_v + q_at_qj * cos_v;\n" << " " << q_output.SetByOffset("qj", "q_im") << "\n" - // Conditionally process Key (only for heads that exist in K domain) - << " if (bsnh[2] < uniforms.k_global_shape[2]) {\n" + << " if (needs_k) {\n" << " let ki = dot(bsnh, uniforms.k_input_output_stride) + select(0u, bsnh[3], " << interleaved_str << ");\n" << " let kj = ki + select(half_rotary_dim, 1u, " << interleaved_str << ");\n" - << " let k_re = " << k_input.GetByOffset("ki") << " * cos_v - " << k_input.GetByOffset("kj") << " * sin_v;\n" + << " let k_at_ki = " << load_k("ki", c_i) << ";\n" + << " let k_at_kj = " << load_k("kj", c_j) << ";\n" + << " let k_re = k_at_ki * cos_v - k_at_kj * sin_v;\n" << " " << k_output.SetByOffset("ki", "k_re") << "\n" - << " let k_im = " << k_input.GetByOffset("ki") << " * sin_v + " << k_input.GetByOffset("kj") << " * cos_v;\n" + << " let k_im = k_at_ki * sin_v + k_at_kj * cos_v;\n" << " " << k_output.SetByOffset("kj", "k_im") << "\n" - " }\n" - " }\n" + << " }\n" + << " }\n" << " } else {\n" << " let qk = dot(bsnh, uniforms.q_input_output_stride) + half_rotary_dim;\n" - << " " << q_output.SetByOffset("qk", q_input.GetByOffset("qk")) << "\n" - // Conditionally process Key (only for heads that exist in K domain) - << " if (bsnh[2] < uniforms.k_global_shape[2]) {\n" + << " let q_at_qk = " << load_q("qk", c_k) << ";\n" + << " " << q_output.SetByOffset("qk", "q_at_qk") << "\n" + << " if (needs_k) {\n" << " let kk = dot(bsnh, uniforms.k_input_output_stride) + half_rotary_dim;\n" - << " " << k_output.SetByOffset("kk", k_input.GetByOffset("kk")) << "\n" + << " let k_at_kk = " << load_k("kk", c_k) << ";\n" + << " " << k_output.SetByOffset("kk", "k_at_kk") << "\n" << " }\n" << " }\n"; return Status::OK(); diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h index e3dc4468cb3ed..dd16630e436bc 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.h @@ -31,12 +31,20 @@ class RotaryEmbeddingProgram final : public Program { class FusedQKRotaryEmbeddingProgram final : public Program { public: - FusedQKRotaryEmbeddingProgram(bool interleaved) : Program{"FusedQKRotaryEmbedding"}, interleaved_{interleaved} {} + FusedQKRotaryEmbeddingProgram(bool interleaved, bool has_qk_norm) + : Program{"FusedQKRotaryEmbedding"}, + interleaved_{interleaved}, + has_qk_norm_{has_qk_norm} {} Status GenerateShaderCode(ShaderHelper& sh) const override; // q_* describes query rotation domain (same definition as existing program) - // k_* describes key rotation domain + // k_* describes key rotation domain. + // When has_qk_norm_ is true, the program also fuses a per-head RMS normalization + // (epsilon = qk_norm_epsilon, scale = q_norm_weight / k_norm_weight) over the + // head_size channels of Q and K before the rotary rotation. head_size and + // qk_norm_epsilon are required uniforms when has_qk_norm_ is true; they are + // ignored otherwise but must still be supplied (callers pass placeholder values). WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( {"scale", ProgramUniformVariableDataType::Float32}, {"q_global_shape", ProgramUniformVariableDataType::Uint32}, @@ -44,10 +52,13 @@ class FusedQKRotaryEmbeddingProgram final : public Program SimplifiedLayerNormalization -> Reshape stack; downstream graph fusion " + "folds that pattern into this input. Currently honored by the native WebGPU execution provider only; " + "JSEP WebGPU/JS and other EPs must reject the node when this input is set.", + "T", + OpSchema::Optional) + .Input(15, + "k_norm_weight", + "Optional 1D tensor of shape (head_size). See q_norm_weight. Must be provided together with q_norm_weight.", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 702f20a96dccf..2f9a083a9b1be 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -55,6 +55,7 @@ #include "core/optimizer/layer_norm_fusion.h" #include "core/optimizer/matmul_activation_fusion.h" #include "core/optimizer/matmul_add_fusion.h" +#include "core/optimizer/group_query_attention_pre_norm_fusion.h" #include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" @@ -446,6 +447,8 @@ InlinedVector> GenerateTransformers( #endif transformers.emplace_back(std::make_unique(cpu_ep)); + transformers.emplace_back(std::make_unique( + InlinedHashSet{onnxruntime::kWebGpuExecutionProvider})); #endif // !defined(DISABLE_CONTRIB_OPS) // The QDQFinalCleanupTransformer must run AFTER other transformers that fuse Q/DQ nodes. Otherwise, their diff --git a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc new file mode 100644 index 0000000000000..909229822f134 --- /dev/null +++ b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc @@ -0,0 +1,403 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/group_query_attention_pre_norm_fusion.h" + +#include +#include +#include +#include + +#include "core/graph/graph_utils.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +namespace { + +constexpr const char* kQkNormEpsilonAttrName = "qk_norm_epsilon"; +constexpr float kEpsilonTolerance = 1e-9f; + +bool HasInput(const Node& node, size_t index) { + return index < node.InputDefs().size() && node.InputDefs()[index] != nullptr && + !node.InputDefs()[index]->Name().empty(); +} + +bool HasProducedOutput(const Node& node, size_t index) { + return index < node.OutputDefs().size() && node.OutputDefs()[index] != nullptr && + !node.OutputDefs()[index]->Name().empty(); +} + +bool IsGraphOutput(const Graph& graph, const NodeArg* arg) { + if (arg == nullptr || arg->Name().empty()) { + return false; + } + for (const auto* graph_output : graph.GetOutputs()) { + if (graph_output != nullptr && graph_output->Name() == arg->Name()) { + return true; + } + } + return false; +} + +// Walks back from `consumer` via input slot `consumer_input_index` and matches: +// producer_proj -> Reshape(reshape_inner) -> SimplifiedLayerNormalization(sln) -> Reshape(reshape_outer) -> consumer +// (`reshape_inner` is the one closest to the projection: it reshapes the (batch, seq, hidden) +// tensor to (batch, seq, num_heads, head_size). `reshape_outer` is the one closest to the +// consumer: it folds back to (batch, seq, hidden).) +// On success returns true and fills the out-pointers. Each intermediate node must have a single +// consumer (the next op in the chain) and must not be a graph output. +bool MatchPreNormReshapeChain(Graph& graph, + Node& consumer, + int consumer_input_index, + int64_t expected_head_size, + int64_t expected_hidden_size, + Node*& reshape_outer_out, + Node*& sln_out, + Node*& reshape_inner_out, + NodeArg*& projection_arg_out, + NodeArg*& norm_weight_arg_out, + float& epsilon_out) { + reshape_outer_out = nullptr; + sln_out = nullptr; + reshape_inner_out = nullptr; + projection_arg_out = nullptr; + norm_weight_arg_out = nullptr; + epsilon_out = 0.0f; + + if (consumer_input_index < 0 || + static_cast(consumer_input_index) >= consumer.InputDefs().size()) { + return false; + } + + NodeArg* consumer_input = consumer.MutableInputDefs()[consumer_input_index]; + if (consumer_input == nullptr || consumer_input->Name().empty()) { + return false; + } + + Node* reshape_outer = graph.GetMutableProducerNode(consumer_input->Name()); + if (reshape_outer == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_outer, "Reshape", {5, 13, 14, 19, 21, 23})) { + return false; + } + if (reshape_outer->GetOutputEdgesCount() != 1) { + return false; + } + if (IsGraphOutput(graph, reshape_outer->OutputDefs()[0])) { + return false; + } + + // Validate outer reshape output last dim equals hidden size (num_heads * head_size). + const auto* reshape_outer_shape = reshape_outer->OutputDefs()[0]->Shape(); + if (reshape_outer_shape == nullptr || reshape_outer_shape->dim_size() < 1) { + return false; + } + const auto& reshape_outer_last = reshape_outer_shape->dim(reshape_outer_shape->dim_size() - 1); + if (!reshape_outer_last.has_dim_value() || reshape_outer_last.dim_value() != expected_hidden_size) { + return false; + } + + if (reshape_outer->InputDefs().empty() || reshape_outer->InputDefs()[0] == nullptr) { + return false; + } + Node* sln = graph.GetMutableProducerNode(reshape_outer->InputDefs()[0]->Name()); + if (sln == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*sln, "SimplifiedLayerNormalization", {1})) { + return false; + } + if (sln->GetOutputEdgesCount() != 1) { + return false; + } + if (IsGraphOutput(graph, sln->OutputDefs()[0])) { + return false; + } + // SLN may emit auxiliary outputs (mean / inv_std). They must not be consumed elsewhere. + for (size_t i = 1; i < sln->OutputDefs().size(); ++i) { + if (HasProducedOutput(*sln, i)) { + return false; + } + } + if (sln->InputDefs().size() < 2 || sln->InputDefs()[1] == nullptr || + sln->InputDefs()[1]->Name().empty()) { + return false; + } + + // SimplifiedLayerNormalization permits its input (T), scale (V) and output (T) to use different + // element types. The fused GroupQueryAttention input slots reuse the projection's element type + // (T), so we can only fuse when scale and output also use T -- otherwise the rewrite would + // change the node's type constraints and produce a semantically different graph. Require all + // three to match before fusing. + auto get_elem_type = [](const NodeArg* arg) -> int32_t { + if (arg == nullptr) { + return ONNX_NAMESPACE::TensorProto::UNDEFINED; + } + const auto* type_proto = arg->TypeAsProto(); + if (type_proto == nullptr || !type_proto->has_tensor_type() || + !type_proto->tensor_type().has_elem_type()) { + return ONNX_NAMESPACE::TensorProto::UNDEFINED; + } + return type_proto->tensor_type().elem_type(); + }; + const int32_t sln_input_elem_type = get_elem_type(sln->InputDefs()[0]); + const int32_t sln_scale_elem_type = get_elem_type(sln->InputDefs()[1]); + const int32_t sln_output_elem_type = get_elem_type(sln->OutputDefs()[0]); + if (sln_input_elem_type == ONNX_NAMESPACE::TensorProto::UNDEFINED || + sln_input_elem_type != sln_scale_elem_type || + sln_input_elem_type != sln_output_elem_type) { + return false; + } + + // Norm weight must be an initializer of shape [head_size]. + NodeArg* norm_weight_arg = sln->MutableInputDefs()[1]; + const ONNX_NAMESPACE::TensorProto* norm_weight_tensor = + graph_utils::GetConstantInitializer(graph, norm_weight_arg->Name()); + if (norm_weight_tensor == nullptr) { + return false; + } + if (norm_weight_tensor->dims_size() != 1 || norm_weight_tensor->dims(0) != expected_head_size) { + return false; + } + + const auto* sln_axis_attr = graph_utils::GetNodeAttribute(*sln, "axis"); + const int64_t sln_axis = (sln_axis_attr == nullptr) ? -1 : sln_axis_attr->i(); + if (sln_axis != -1) { + return false; + } + const auto* sln_eps_attr = graph_utils::GetNodeAttribute(*sln, "epsilon"); + const float sln_eps = (sln_eps_attr == nullptr) ? 1e-5f : sln_eps_attr->f(); + + // Inner reshape (between projection and SLN). + if (sln->InputDefs().empty() || sln->InputDefs()[0] == nullptr) { + return false; + } + Node* reshape_inner = graph.GetMutableProducerNode(sln->InputDefs()[0]->Name()); + if (reshape_inner == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_inner, "Reshape", {5, 13, 14, 19, 21, 23})) { + return false; + } + if (reshape_inner->GetOutputEdgesCount() != 1) { + return false; + } + if (IsGraphOutput(graph, reshape_inner->OutputDefs()[0])) { + return false; + } + const auto* reshape_inner_shape = reshape_inner->OutputDefs()[0]->Shape(); + if (reshape_inner_shape == nullptr || reshape_inner_shape->dim_size() < 1) { + return false; + } + const auto& reshape_inner_last = reshape_inner_shape->dim(reshape_inner_shape->dim_size() - 1); + if (!reshape_inner_last.has_dim_value() || reshape_inner_last.dim_value() != expected_head_size) { + return false; + } + + if (reshape_inner->InputDefs().empty() || reshape_inner->InputDefs()[0] == nullptr) { + return false; + } + + reshape_outer_out = reshape_outer; + sln_out = sln; + reshape_inner_out = reshape_inner; + projection_arg_out = reshape_inner->MutableInputDefs()[0]; + norm_weight_arg_out = norm_weight_arg; + epsilon_out = sln_eps; + return true; +} + +} // namespace + +Status GroupQueryAttentionPreNormFusion::ApplyImpl(Graph& graph, + bool& modified, + int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* node_ptr = graph.GetNode(node_index); + if (node_ptr == nullptr) { + continue; + } + Node& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "GroupQueryAttention", {1}, kMSDomain) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + continue; + } + + // Already fused? + if (HasInput(node, 14) || HasInput(node, 15)) { + continue; + } + + // Need at least query (0), key (1), value (2), past_key (3) so we can read head_size. + // Requiring K at slot 1 also excludes the packed-QKV form (Q occupies slot 0 and K/V + // slots are empty), which the WebGPU fused prologue does not support. + if (node.InputDefs().size() < 4 || !HasInput(node, 0) || !HasInput(node, 1) || !HasInput(node, 2)) { + continue; + } + + // The fused decode prologue only applies when rotary embedding is enabled (Qwen3-style + // configuration). If the GQA node has do_rotary=0 the kernel will reject the rewritten + // node, so skip the fusion here to avoid that regression. + const auto& gqa_attrs = node.GetAttributes(); + auto do_rotary_it = gqa_attrs.find("do_rotary"); + const int64_t do_rotary = (do_rotary_it == gqa_attrs.end()) ? 0 : do_rotary_it->second.i(); + if (do_rotary != 1) { + continue; + } + const NodeArg* past_key_arg = node.InputDefs()[3]; + if (past_key_arg == nullptr || past_key_arg->Shape() == nullptr || + past_key_arg->Shape()->dim_size() < 4) { + continue; + } + const auto& head_size_dim = past_key_arg->Shape()->dim(3); + if (!head_size_dim.has_dim_value()) { + continue; + } + const int64_t head_size = head_size_dim.dim_value(); + + auto num_heads_it = gqa_attrs.find("num_heads"); + auto kv_num_heads_it = gqa_attrs.find("kv_num_heads"); + if (num_heads_it == gqa_attrs.end() || kv_num_heads_it == gqa_attrs.end()) { + continue; + } + const int64_t num_heads = num_heads_it->second.i(); + const int64_t kv_num_heads = kv_num_heads_it->second.i(); + const int64_t q_hidden_size = num_heads * head_size; + const int64_t kv_hidden_size = kv_num_heads * head_size; + + // Match pre-norm Reshape -> SLN -> Reshape on Q (slot 0) and K (slot 1). + Node* q_reshape_outer = nullptr; + Node* q_sln = nullptr; + Node* q_reshape_inner = nullptr; + NodeArg* q_projection_arg = nullptr; + NodeArg* q_norm_weight_arg = nullptr; + float q_epsilon = 0.0f; + if (!MatchPreNormReshapeChain(graph, node, /*consumer_input_index=*/0, head_size, q_hidden_size, + q_reshape_outer, q_sln, q_reshape_inner, + q_projection_arg, q_norm_weight_arg, q_epsilon)) { + continue; + } + + Node* k_reshape_outer = nullptr; + Node* k_sln = nullptr; + Node* k_reshape_inner = nullptr; + NodeArg* k_projection_arg = nullptr; + NodeArg* k_norm_weight_arg = nullptr; + float k_epsilon = 0.0f; + if (!MatchPreNormReshapeChain(graph, node, /*consumer_input_index=*/1, head_size, kv_hidden_size, + k_reshape_outer, k_sln, k_reshape_inner, + k_projection_arg, k_norm_weight_arg, k_epsilon)) { + continue; + } + + if (std::fabs(q_epsilon - k_epsilon) > kEpsilonTolerance) { + continue; + } + + LOGS(logger, VERBOSE) << "GroupQueryAttentionPreNormFusion: matched gqa='" << node.Name() + << "' q_sln='" << q_sln->Name() << "' k_sln='" << k_sln->Name() + << "' head_size=" << head_size + << " num_heads=" << num_heads << " kv_num_heads=" << kv_num_heads + << " epsilon=" << q_epsilon; + + // Build new GQA inputs: copy existing inputs, replace 0/1 with projection outputs, + // pad up to slot 13 with empty NodeArgs, then add q/k norm weights at 14/15. + NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr); + InlinedVector new_inputs; + new_inputs.reserve(16); + for (size_t i = 0; i < 16; ++i) { + if (i == 0) { + new_inputs.push_back(q_projection_arg); + } else if (i == 1) { + new_inputs.push_back(k_projection_arg); + } else if (i == 14) { + new_inputs.push_back(q_norm_weight_arg); + } else if (i == 15) { + new_inputs.push_back(k_norm_weight_arg); + } else if (i < node.InputDefs().size()) { + NodeArg* existing = node.MutableInputDefs()[i]; + new_inputs.push_back((existing != nullptr && !existing->Name().empty()) ? existing : &empty_arg); + } else { + new_inputs.push_back(&empty_arg); + } + } + + // Outputs: keep the same NodeArgs so downstream consumers and graph outputs are preserved. + InlinedVector new_outputs; + new_outputs.reserve(node.OutputDefs().size()); + for (auto* out : node.OutputDefs()) { + new_outputs.push_back(const_cast(out)); + } + + // Copy attributes and add qk_norm_epsilon. + NodeAttributes new_attrs = node.GetAttributes(); + utils::SetNodeAttribute(utils::MakeAttribute(std::string(kQkNormEpsilonAttrName), q_epsilon), new_attrs); + + const std::string original_name = node.Name(); + const std::string original_ep = node.GetExecutionProviderType(); + + // Snapshot the GQA's original input edges (we will rewire them, except for slots 0/1). + auto gqa_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(node); + auto gqa_output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node); + + // Remove all involved nodes (their input edges from elsewhere drop with them). + graph_utils::RemoveNodeOutputEdges(graph, node); + graph.RemoveNode(node.Index()); + graph_utils::RemoveNodeOutputEdges(graph, *q_reshape_outer); + graph.RemoveNode(q_reshape_outer->Index()); + graph_utils::RemoveNodeOutputEdges(graph, *q_sln); + graph.RemoveNode(q_sln->Index()); + graph_utils::RemoveNodeOutputEdges(graph, *q_reshape_inner); + graph.RemoveNode(q_reshape_inner->Index()); + graph_utils::RemoveNodeOutputEdges(graph, *k_reshape_outer); + graph.RemoveNode(k_reshape_outer->Index()); + graph_utils::RemoveNodeOutputEdges(graph, *k_sln); + graph.RemoveNode(k_sln->Index()); + graph_utils::RemoveNodeOutputEdges(graph, *k_reshape_inner); + graph.RemoveNode(k_reshape_inner->Index()); + + Node& fused = graph.AddNode(graph.GenerateNodeName(original_name + "_qknorm"), + "GroupQueryAttention", + "GroupQueryAttention with fused per-head Q/K RMSNorm", + new_inputs, + new_outputs, + &new_attrs, + kMSDomain); + fused.SetExecutionProviderType(original_ep); + + // Rewire upstream edges that fed the original GQA. Skip slots 0 and 1 (now driven by + // the projection outputs which are still produced by their upstream nodes; the + // graph.AddNode + matching NodeArg name will let the graph's edge resolver re-attach + // those producer edges automatically when Resolve() runs, but we add them explicitly + // for safety). + for (const auto& e : gqa_input_edges) { + if (e.dst_arg_index == 0 || e.dst_arg_index == 1) { + continue; + } + graph.AddEdge(e.src_node, fused.Index(), e.src_arg_index, e.dst_arg_index); + } + // Add explicit edges for the new query/key inputs from the projection nodes. + if (Node* q_proj_node = graph.GetMutableProducerNode(q_projection_arg->Name())) { + const int src_idx = graph_utils::GetNodeOutputIndexFromOutputName(*q_proj_node, q_projection_arg->Name()); + graph.AddEdge(q_proj_node->Index(), fused.Index(), src_idx, 0); + } + if (Node* k_proj_node = graph.GetMutableProducerNode(k_projection_arg->Name())) { + const int src_idx = graph_utils::GetNodeOutputIndexFromOutputName(*k_proj_node, k_projection_arg->Name()); + graph.AddEdge(k_proj_node->Index(), fused.Index(), src_idx, 1); + } + // Rewire downstream edges from the original GQA outputs. + for (const auto& e : gqa_output_edges) { + graph.AddEdge(fused.Index(), e.dst_node, e.src_arg_index, e.dst_arg_index); + } + + modified = true; + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h new file mode 100644 index 0000000000000..b69199bb5324d --- /dev/null +++ b/onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@Class GroupQueryAttentionPreNormFusion + +Folds the Qwen3-style per-head Q/K RMSNorm prologue into the GroupQueryAttention +node by adding optional q_norm_weight and k_norm_weight inputs (slots 14 and 15) +and a qk_norm_epsilon attribute. The transform looks for the following pattern +on inputs 0 (query) and 1 (key) of an unfused GroupQueryAttention node: + + Q_proj_out -> Reshape[*,*,head_size] + -> SimplifiedLayerNormalization(weight = q_norm_weight) + -> Reshape[*,*,num_heads * head_size] + -> GQA[input 0] + + K_proj_out -> Reshape[*,*,head_size] + -> SimplifiedLayerNormalization(weight = k_norm_weight) + -> Reshape[*,*,kv_num_heads * head_size] + -> GQA[input 1] + +When matched, the six Reshape/SLN nodes are removed and the pre-norm Q and K +projections feed GQA directly. The kernel is responsible for applying the RMS +norm internally (currently the WebGPU EP). + +Only fires for execution providers passed in `compatible_execution_providers`. +At present this fusion is registered for the WebGPU EP only, because the +in-kernel norm path is currently implemented there. The CPU, CUDA, and JSEP +GroupQueryAttention kernels reject q_norm_weight / k_norm_weight inputs. +*/ +class GroupQueryAttentionPreNormFusion : public GraphTransformer { + public: + explicit GroupQueryAttentionPreNormFusion( + const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GroupQueryAttentionPreNormFusion", compatible_execution_providers) { + } + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index 636f185eda422..39cce7941a15b 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -159,4 +159,4 @@ constexpr const char GeluExpr[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475 constexpr const char FastGeluExpr[] = "a * (0.5 + 0.5 * tanh_v(a * (0.035677408136300125 * a * a + 0.7978845608028654)))"; } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 7d4ae8c2197ff..3afedea30adaf 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -162,8 +162,6 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); const uint32_t norm_count = onnxruntime::narrow(x_shape.SizeToDimension(axis)); const int64_t norm_size = x_shape.SizeFromDimension(axis); - const int components = GetMaxComponents(norm_size); - const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); const auto scale_size = scale->Shape().Size(); const auto bias_size = (bias) ? bias->Shape().Size() : 0; @@ -192,6 +190,28 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex return Status::OK(); } + return RunLayerNormProgram(context, x, scale, bias, epsilon_, norm_count, norm_size, + simplified, y, mean, inv_std_dev); +} + +Status RunLayerNormProgram(ComputeContext& context, + const Tensor* x, + const Tensor* scale, + const Tensor* bias, + float epsilon, + uint32_t norm_count, + int64_t norm_size, + bool simplified, + Tensor* y, + Tensor* mean, + Tensor* inv_std_dev) { + if (x->Shape().Size() == 0) { + return Status::OK(); + } + + const int components = GetMaxComponents(norm_size); + const uint32_t norm_size_vectorized = onnxruntime::narrow((norm_size + components - 1) / components); + // Check if we should use split norm dimension optimization const bool split_norm_dim = norm_size % 512 == 0 && norm_count == 1; @@ -215,7 +235,7 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex {static_cast(norm_size_vectorized)}, }) .AddUniformVariables({ - {static_cast(epsilon_)}, + {static_cast(epsilon)}, }); if (split_norm_dim) { diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h index 112b152d37130..a6323dc7721d4 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.h +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h @@ -56,5 +56,21 @@ class LayerNorm final : public WebGpuKernel { int64_t stash_type_; }; +// Configures and dispatches a LayerNormProgram. Centralizes the program-setup logic +// (uniform variables, components, split_norm_dim heuristic, workgroup sizing) so callers +// other than the LayerNorm kernel (e.g. fused MatMulNBits ops) do not need to duplicate it. +// `bias`, `mean` and `inv_std_dev` may be nullptr. +Status RunLayerNormProgram(ComputeContext& context, + const Tensor* x, + const Tensor* scale, + const Tensor* bias, + float epsilon, + uint32_t norm_count, + int64_t norm_size, + bool simplified, + Tensor* y, + Tensor* mean, + Tensor* inv_std_dev); + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc new file mode 100644 index 0000000000000..4f3deb2306f38 --- /dev/null +++ b/onnxruntime/test/optimizer/group_query_attention_pre_norm_fusion_test.cc @@ -0,0 +1,333 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/group_query_attention_pre_norm_fusion.h" +#include "core/optimizer/utils.h" + +#include "test/util/include/asserts.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/unittest_util/graph_transform_test_builder.h" +#include "test/optimizer/graph_transform_test_fixture.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +namespace { + +// Small geometry that exercises the Q/K post-norm pattern without needing real GPU work. +constexpr int64_t kBatch = 1; +constexpr int64_t kSeq = 1; +constexpr int64_t kNumHeads = 2; +constexpr int64_t kKvNumHeads = 1; +constexpr int64_t kHeadSize = 4; +constexpr int64_t kQHidden = kNumHeads * kHeadSize; +constexpr int64_t kKvHidden = kKvNumHeads * kHeadSize; +constexpr int64_t kMaxSeq = 8; + +void SetWebGpu(Node& node) { node.SetExecutionProviderType(kWebGpuExecutionProvider); } + +// Builds: [Reshape -> SimplifiedLayerNormalization -> Reshape] on Q and K, feeding a +// GroupQueryAttention node. V goes straight into GQA. The pattern is configured via +// BuildOptions so individual tests can flip a single attribute / shape / epsilon to +// exercise each gate. +struct BuildOptions { + float q_epsilon = 1e-6f; + float k_epsilon = 1e-6f; + // If true, the inner reshape on the K side targets a different last-dim than head_size + // so the matcher must reject it. + bool break_k_inner_reshape_shape = false; + // If true, the q_norm_weight initializer is given a non-1D shape so the matcher must + // reject it. + bool break_q_norm_weight_shape = false; + // GQA do_rotary attribute. The WebGPU fused prologue only supports do_rotary=1, so the + // optimizer must skip the rewrite when this is 0. + int64_t do_rotary = 1; + // If true, drop the K input from the GQA node (slot 1 empty), simulating the packed-QKV + // form. The optimizer must skip the rewrite in that case. + bool packed_qkv = false; + // If true, pre-populate the GQA node's slot 14 with a q_norm_weight initializer so the + // optimizer treats the node as already fused and skips it. + bool pre_fused = false; +}; + +void BuildQwenQkPostNormPattern(ModelTestBuilder& builder, const BuildOptions& opts) { + // Projection inputs (post linear projection, pre norm). + NodeArg* q_proj = builder.MakeInput( + std::vector{kBatch, kSeq, kQHidden}, MLFloat16(-1.0f), MLFloat16(1.0f)); + NodeArg* k_proj = builder.MakeInput( + std::vector{kBatch, kSeq, kKvHidden}, MLFloat16(-1.0f), MLFloat16(1.0f)); + NodeArg* v_proj = builder.MakeInput( + std::vector{kBatch, kSeq, kKvHidden}, MLFloat16(-1.0f), MLFloat16(1.0f)); + + // GQA cache + control inputs. + NodeArg* past_key = builder.MakeInput( + std::vector{kBatch, kKvNumHeads, kMaxSeq, kHeadSize}, MLFloat16(0.0f), MLFloat16(0.0f)); + NodeArg* past_value = builder.MakeInput( + std::vector{kBatch, kKvNumHeads, kMaxSeq, kHeadSize}, MLFloat16(0.0f), MLFloat16(0.0f)); + // Note: ModelTestBuilder::MakeInput(shape, min, max) calls Uniform(min, max - 1) + // internally, which asserts on min == max. Use the explicit-data overload instead. + NodeArg* seqlens_k = builder.MakeInput(std::vector{kBatch}, std::vector{0}); + NodeArg* total_seq_len = builder.MakeInput(std::vector{1}, std::vector{1}); + + // Norm weight initializers: [head_size]. (Or non-1D when forcing a shape mismatch.) + std::vector q_norm_weight_shape = + opts.break_q_norm_weight_shape ? std::vector{1, kHeadSize} : std::vector{kHeadSize}; + NodeArg* q_norm_weight = builder.MakeInitializer(q_norm_weight_shape, MLFloat16(1.0f), MLFloat16(1.0f)); + NodeArg* k_norm_weight = builder.MakeInitializer({kHeadSize}, MLFloat16(1.0f), MLFloat16(1.0f)); + + // Reshape "shape" initializers. + NodeArg* reshape_to_per_head_q = builder.MakeInitializer({4}, {kBatch, kSeq, kNumHeads, kHeadSize}); + const int64_t k_inner_last_dim = opts.break_k_inner_reshape_shape ? (kHeadSize * 2) : kHeadSize; + NodeArg* reshape_to_per_head_k = + builder.MakeInitializer({4}, {kBatch, kSeq, kKvNumHeads, k_inner_last_dim}); + NodeArg* reshape_to_q_hidden = builder.MakeInitializer({3}, {kBatch, kSeq, kQHidden}); + NodeArg* reshape_to_kv_hidden = builder.MakeInitializer({3}, {kBatch, kSeq, kKvHidden}); + + // Q-side chain. + NodeArg* q_inner_reshape_out = builder.MakeIntermediate( + std::vector{kBatch, kSeq, kNumHeads, kHeadSize}); + NodeArg* q_normed = builder.MakeIntermediate( + std::vector{kBatch, kSeq, kNumHeads, kHeadSize}); + NodeArg* q_outer_reshape_out = builder.MakeIntermediate( + std::vector{kBatch, kSeq, kQHidden}); + + Node& q_inner_reshape = builder.AddNode("Reshape", {q_proj, reshape_to_per_head_q}, {q_inner_reshape_out}); + Node& q_sln = builder.AddNode("SimplifiedLayerNormalization", {q_inner_reshape_out, q_norm_weight}, {q_normed}); + q_sln.AddAttribute("axis", static_cast(-1)); + q_sln.AddAttribute("epsilon", opts.q_epsilon); + Node& q_outer_reshape = builder.AddNode("Reshape", {q_normed, reshape_to_q_hidden}, {q_outer_reshape_out}); + + // K-side chain. + NodeArg* k_inner_reshape_out = builder.MakeIntermediate( + std::vector{kBatch, kSeq, kKvNumHeads, k_inner_last_dim}); + NodeArg* k_normed = builder.MakeIntermediate( + std::vector{kBatch, kSeq, kKvNumHeads, k_inner_last_dim}); + NodeArg* k_outer_reshape_out = builder.MakeIntermediate( + std::vector{kBatch, kSeq, kKvHidden}); + + Node& k_inner_reshape = builder.AddNode("Reshape", {k_proj, reshape_to_per_head_k}, {k_inner_reshape_out}); + Node& k_sln = builder.AddNode("SimplifiedLayerNormalization", {k_inner_reshape_out, k_norm_weight}, {k_normed}); + k_sln.AddAttribute("axis", static_cast(-1)); + k_sln.AddAttribute("epsilon", opts.k_epsilon); + Node& k_outer_reshape = builder.AddNode("Reshape", {k_normed, reshape_to_kv_hidden}, {k_outer_reshape_out}); + + // GQA outputs. + NodeArg* gqa_out = builder.MakeOutput(std::vector{kBatch, kSeq, kQHidden}); + NodeArg* present_key = builder.MakeOutput( + std::vector{kBatch, kKvNumHeads, kMaxSeq, kHeadSize}); + NodeArg* present_value = builder.MakeOutput( + std::vector{kBatch, kKvNumHeads, kMaxSeq, kHeadSize}); + + // Build the GQA input list. The packed_qkv variant drops K (slot 1) and V (slot 2). + NodeArg& empty_arg = builder.graph_.GetOrCreateNodeArg("", nullptr); + std::vector gqa_inputs; + gqa_inputs.push_back(q_outer_reshape_out); + gqa_inputs.push_back(opts.packed_qkv ? &empty_arg : k_outer_reshape_out); + gqa_inputs.push_back(opts.packed_qkv ? &empty_arg : v_proj); + gqa_inputs.push_back(past_key); + gqa_inputs.push_back(past_value); + gqa_inputs.push_back(seqlens_k); + gqa_inputs.push_back(total_seq_len); + + if (opts.pre_fused) { + // Pad slots 7..13 with empty args, then place a real norm weight in slot 14. + for (int i = 7; i < 14; ++i) { + gqa_inputs.push_back(&empty_arg); + } + NodeArg* preexisting_q_norm = + builder.MakeInitializer({kHeadSize}, MLFloat16(1.0f), MLFloat16(1.0f)); + gqa_inputs.push_back(preexisting_q_norm); + } + + Node& gqa = builder.AddNode("GroupQueryAttention", + gqa_inputs, + {gqa_out, present_key, present_value}, + kMSDomain); + gqa.AddAttribute("num_heads", static_cast(kNumHeads)); + gqa.AddAttribute("kv_num_heads", static_cast(kKvNumHeads)); + gqa.AddAttribute("do_rotary", opts.do_rotary); + + SetWebGpu(q_inner_reshape); + SetWebGpu(q_sln); + SetWebGpu(q_outer_reshape); + SetWebGpu(k_inner_reshape); + SetWebGpu(k_sln); + SetWebGpu(k_outer_reshape); + SetWebGpu(gqa); +} + +Status CheckFusedGraph(Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + if (OpCount(op_to_count, "com.microsoft.GroupQueryAttention") != 1 || + OpCount(op_to_count, "SimplifiedLayerNormalization") != 0 || + OpCount(op_to_count, "Reshape") != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Unexpected op counts after GroupQueryAttentionPreNormFusion: ", + "GQA=", OpCount(op_to_count, "com.microsoft.GroupQueryAttention"), + " SLN=", OpCount(op_to_count, "SimplifiedLayerNormalization"), + " Reshape=", OpCount(op_to_count, "Reshape")); + } + + for (const auto& node : graph.Nodes()) { + if (node.OpType() != "GroupQueryAttention") continue; + ORT_RETURN_IF_NOT(node.InputDefs().size() >= 16, "Fused GQA must expose 16 input slots."); + + const auto* q_norm = node.InputDefs()[14]; + const auto* k_norm = node.InputDefs()[15]; + ORT_RETURN_IF_NOT(q_norm != nullptr && q_norm->Exists(), "q_norm_weight (slot 14) missing."); + ORT_RETURN_IF_NOT(k_norm != nullptr && k_norm->Exists(), "k_norm_weight (slot 15) missing."); + + const auto& attrs = node.GetAttributes(); + auto eps_it = attrs.find("qk_norm_epsilon"); + ORT_RETURN_IF_NOT(eps_it != attrs.end(), "qk_norm_epsilon attribute missing."); + ORT_RETURN_IF_NOT(std::abs(eps_it->second.f() - 1e-6f) < 1e-9f, "qk_norm_epsilon value mismatch."); + } + return Status::OK(); +} + +Status CheckUnfusedGraph(Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + if (OpCount(op_to_count, "com.microsoft.GroupQueryAttention") != 1 || + OpCount(op_to_count, "SimplifiedLayerNormalization") != 2 || + OpCount(op_to_count, "Reshape") != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Negative test: graph was fused unexpectedly."); + } + for (const auto& node : graph.Nodes()) { + if (node.OpType() != "GroupQueryAttention") continue; + if (node.InputDefs().size() >= 15) { + const auto* q_norm = node.InputDefs()[14]; + ORT_RETURN_IF_NOT(q_norm == nullptr || !q_norm->Exists(), + "Negative test: q_norm_weight should not be wired."); + } + } + return Status::OK(); +} + +} // namespace + +// Helper: build the transformer registered for the WebGPU EP only (matches production). +std::unique_ptr MakeWebGpuTransformer() { + return std::make_unique( + InlinedHashSet{kWebGpuExecutionProvider}); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionFusesQwenPattern) { + auto build = [](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, BuildOptions{}); }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckFusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsEpsilonMismatch) { + BuildOptions opts; + opts.q_epsilon = 1e-6f; + opts.k_epsilon = 1e-5f; + auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsBadInnerReshape) { + BuildOptions opts; + opts.break_k_inner_reshape_shape = true; + auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionRejectsNon1DNormWeight) { + BuildOptions opts; + opts.break_q_norm_weight_shape = true; + auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsCpuEp) { + // Build the pattern but assign all nodes to CPU EP. The fusion is gated to WebGPU only, + // so the graph must remain unfused. + auto build = [](ModelTestBuilder& builder) { + BuildQwenQkPostNormPattern(builder, BuildOptions{}); + for (auto& node : builder.graph_.Nodes()) { + const_cast(node).SetExecutionProviderType(kCpuExecutionProvider); + } + }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsJsEp) { + // JSEP does not implement the fused per-head Q/K RMSNorm prologue, so the optimizer + // (which we now register for WebGPU only) must leave JSEP-assigned graphs alone. + auto build = [](ModelTestBuilder& builder) { + BuildQwenQkPostNormPattern(builder, BuildOptions{}); + for (auto& node : builder.graph_.Nodes()) { + const_cast(node).SetExecutionProviderType(kJsExecutionProvider); + } + }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsWhenDoRotaryDisabled) { + // The WebGPU fused prologue requires do_rotary=1; the optimizer must skip otherwise so + // the runtime guard never trips. + BuildOptions opts; + opts.do_rotary = 0; + auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsPackedQkv) { + // Packed-QKV form leaves slots 1 and 2 empty; the WebGPU fused prologue does not support + // it, so the optimizer must skip the rewrite. + BuildOptions opts; + opts.packed_qkv = true; + auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, CheckUnfusedGraph)); +} + +TEST_F(GraphTransformationTests, GroupQueryAttentionPreNormFusionSkipsAlreadyFusedNode) { + // If the GQA node already exposes a q_norm_weight (slot 14) input the optimizer must + // treat it as already fused and leave the surrounding SLN/Reshape ops in place. The + // standard CheckUnfusedGraph helper rejects any wiring at slot 14, so use a custom + // checker that only verifies the surrounding ops weren't removed. + BuildOptions opts; + opts.pre_fused = true; + auto build = [opts](ModelTestBuilder& builder) { BuildQwenQkPostNormPattern(builder, opts); }; + auto check = [](Graph& graph) -> Status { + const auto op_to_count = CountOpsInGraph(graph); + ORT_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1, + "Already-fused test: GQA count changed."); + ORT_RETURN_IF_NOT(OpCount(op_to_count, "SimplifiedLayerNormalization") == 2, + "Already-fused test: SLN ops were removed."); + ORT_RETURN_IF_NOT(OpCount(op_to_count, "Reshape") == 4, + "Already-fused test: Reshape ops were removed."); + return Status::OK(); + }; + ASSERT_STATUS_OK(TestGraphTransformer( + build, /*opset_version=*/21, *logger_, MakeWebGpuTransformer(), + TransformerLevel::Level2, /*steps=*/1, nullptr, check)); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) + +} // namespace test +} // namespace onnxruntime