Skip to content

Commit 87e9644

Browse files
committed
Address Copilot review comments on PR #28484
- [WebGPU] Validate q/k_norm_weight is 1-D of length head_size in the GQA kernel so a hand-authored model with the wrong shape fails with INVALID_ARGUMENT instead of reading wrong offsets.
- [Optimizer] Require SimplifiedLayerNormalization input/scale/output element types to match before fusing, since the fused GQA input slots reuse the projection's element type (T) and a mixed-type SLN would change the node's type constraints.
- [JSEP] Reject the GQA node when q_norm_weight or k_norm_weight is present regardless of rank (including scalars), instead of only checking dims.length > 0.
1 parent d5e208b commit 87e9644

3 files changed

Lines changed: 42 additions & 5 deletions

File tree

js/web/lib/wasm/jsep/webgpu/ops/group-query-attention.ts

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,8 @@ export const groupQueryAttention = (context: ComputeContext, attributes: GroupQu
331331
// q_norm_weight (input 14) / k_norm_weight (input 15) are emitted by the WebGPU-only
332332
// GroupQueryAttentionPreNormFusion optimizer pass. JSEP does not implement the fused
333333
// per-head Q/K RMS normalization prologue, so reject the node if either input is present
334-
// rather than silently dropping the normalization.
335-
if (
336-
(context.inputs.length > 14 && context.inputs[14] && context.inputs[14].dims.length > 0) ||
337-
(context.inputs.length > 15 && context.inputs[15] && context.inputs[15].dims.length > 0)
338-
) {
334+
// (regardless of rank, including scalars) rather than silently dropping the normalization.
335+
if ((context.inputs.length > 14 && context.inputs[14]) || (context.inputs.length > 15 && context.inputs[15])) {
339336
throw new Error(
340337
'GroupQueryAttention (JSEP): q_norm_weight / k_norm_weight inputs are not supported. ' +
341338
'The per-head Q/K RMS normalization prologue is implemented only on the native WebGPU EP.',

onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,21 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
261261
WebgpuAttentionParameters parameters(params);
262262
ORT_RETURN_IF(has_qk_norm && parameters.is_packed_qkv_,
263263
"GroupQueryAttention: q_norm_weight / k_norm_weight are not supported when QKV is packed.");
264+
if (has_qk_norm) {
265+
// The fused prologue indexes q/k_norm_weight as a 1-D tensor of length head_size. Validate
266+
// shape here so a hand-authored model with a wrong shape fails with INVALID_ARGUMENT instead
267+
// of silently reading the wrong offsets (or out of bounds).
268+
const auto& q_norm_shape = q_norm_weight->Shape();
269+
ORT_RETURN_IF_NOT(q_norm_shape.NumDimensions() == 1 &&
270+
q_norm_shape[0] == static_cast<int64_t>(parameters.head_size_),
271+
"GroupQueryAttention: q_norm_weight must be a 1-D tensor of shape [head_size=",
272+
parameters.head_size_, "], got ", q_norm_shape.ToString(), ".");
273+
const auto& k_norm_shape = k_norm_weight->Shape();
274+
ORT_RETURN_IF_NOT(k_norm_shape.NumDimensions() == 1 &&
275+
k_norm_shape[0] == static_cast<int64_t>(parameters.head_size_),
276+
"GroupQueryAttention: k_norm_weight must be a 1-D tensor of shape [head_size=",
277+
parameters.head_size_, "], got ", k_norm_shape.ToString(), ".");
278+
}
264279
TensorShapeVector output_shape(3);
265280
output_shape[0] = static_cast<int64_t>(parameters.batch_size_);
266281
output_shape[1] = static_cast<int64_t>(parameters.sequence_length_);

onnxruntime/core/optimizer/group_query_attention_pre_norm_fusion.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,31 @@ bool MatchPreNormReshapeChain(Graph& graph,
124124
return false;
125125
}
126126

127+
// SimplifiedLayerNormalization permits its input (T), scale (V) and output (T) to use different
128+
// element types. The fused GroupQueryAttention input slots reuse the projection's element type
129+
// (T), so we can only fuse when scale and output also use T -- otherwise the rewrite would
130+
// change the node's type constraints and produce a semantically different graph. Require all
131+
// three to match before fusing.
132+
auto get_elem_type = [](const NodeArg* arg) -> int32_t {
133+
if (arg == nullptr) {
134+
return ONNX_NAMESPACE::TensorProto::UNDEFINED;
135+
}
136+
const auto* type_proto = arg->TypeAsProto();
137+
if (type_proto == nullptr || !type_proto->has_tensor_type() ||
138+
!type_proto->tensor_type().has_elem_type()) {
139+
return ONNX_NAMESPACE::TensorProto::UNDEFINED;
140+
}
141+
return type_proto->tensor_type().elem_type();
142+
};
143+
const int32_t sln_input_elem_type = get_elem_type(sln->InputDefs()[0]);
144+
const int32_t sln_scale_elem_type = get_elem_type(sln->InputDefs()[1]);
145+
const int32_t sln_output_elem_type = get_elem_type(sln->OutputDefs()[0]);
146+
if (sln_input_elem_type == ONNX_NAMESPACE::TensorProto::UNDEFINED ||
147+
sln_input_elem_type != sln_scale_elem_type ||
148+
sln_input_elem_type != sln_output_elem_type) {
149+
return false;
150+
}
151+
127152
// Norm weight must be an initializer of shape [head_size].
128153
NodeArg* norm_weight_arg = sln->MutableInputDefs()[1];
129154
const ONNX_NAMESPACE::TensorProto* norm_weight_tensor =

0 commit comments

Comments
 (0)