Commit 068b686

feich-ms and claude committed
Fix MHA CrossAttention regression: use num_heads_ for non-GQA paths
The internal present KV buffer shape must use num_heads_ for MHA (where kv_num_heads_ is 0) and kv_num_heads_ only for GQA. Using kv_num_heads_ unconditionally caused zero-sized buffers for MHA CrossAttention tests.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
1 parent 08bbc41 commit 068b686

1 file changed

Lines changed: 3 additions & 2 deletions

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

@@ -425,14 +425,15 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   // Create present_key and present_value tensors if they are nullptr
   Tensor internal_present_key;
   Tensor internal_present_value;
+  const int present_kv_heads = parameters.is_gqa_ ? parameters.kv_num_heads_ : parameters.num_heads_;
   if (present_key == nullptr) {
-    TensorShapeVector present_kv_shape({parameters.batch_size_, parameters.kv_num_heads_,
+    TensorShapeVector present_kv_shape({parameters.batch_size_, present_kv_heads,
                                         parameters.total_sequence_length_, parameters.head_size_});
     internal_present_key = context.CreateGPUTensor(Q->DataType(), TensorShape(present_kv_shape));
     present_key = &internal_present_key;
   }
   if (present_value == nullptr) {
-    TensorShapeVector present_kv_shape({parameters.batch_size_, parameters.kv_num_heads_,
+    TensorShapeVector present_kv_shape({parameters.batch_size_, present_kv_heads,
                                         parameters.total_sequence_length_, parameters.head_size_});
     internal_present_value = context.CreateGPUTensor(Q->DataType(), TensorShape(present_kv_shape));
     present_value = &internal_present_value;
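
The effect of the change can be illustrated with a small standalone sketch. It is not the onnxruntime code: KvParams, PresentKvHeads, PresentKvShape and ElementCount are hypothetical names that only mirror the fields the diff reads, and the sketch assumes, as the commit message states, that kv_num_heads_ is 0 on the MHA path.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the attention parameters used in the diff.
struct KvParams {
  bool is_gqa;
  int batch_size;
  int num_heads;
  int kv_num_heads;  // assumed to be 0 for MHA, per the commit message
  int total_sequence_length;
  int head_size;
};

// Mirrors the selection added by the commit: GQA uses kv_num_heads,
// every other path falls back to num_heads.
int PresentKvHeads(const KvParams& p) {
  return p.is_gqa ? p.kv_num_heads : p.num_heads;
}

// Shape layout used for the internal present key/value buffers:
// {batch, heads, total_sequence_length, head_size}.
std::array<int64_t, 4> PresentKvShape(const KvParams& p, int heads) {
  return {p.batch_size, heads, p.total_sequence_length, p.head_size};
}

// Number of elements in a buffer of the given shape.
int64_t ElementCount(const std::array<int64_t, 4>& shape) {
  int64_t n = 1;
  for (int64_t d : shape) n *= d;
  return n;
}
```

With an MHA configuration such as KvParams{false, 2, 8, 0, 16, 64}, building the shape from kv_num_heads gives an element count of 0, while building it from PresentKvHeads gives 2 * 8 * 16 * 64 elements. For a GQA configuration the two choices coincide, so the selection leaves that path unchanged.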
