Commit 83182ea

feich-ms and claude committed
Fix MHA CrossAttention regression: use num_heads_ for non-GQA paths
The internal present KV buffer shape must use num_heads_ for MHA (where kv_num_heads_ is 0) and kv_num_heads_ only for GQA. Using kv_num_heads_ unconditionally caused zero-sized buffers for MHA CrossAttention tests.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
1 parent 9cf7ea3 commit 83182ea

1 file changed

Lines changed: 3 additions & 2 deletions

onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc

@@ -423,14 +423,15 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co
   // Create present_key and present_value tensors if they are nullptr
   Tensor internal_present_key;
   Tensor internal_present_value;
+  const int present_kv_heads = parameters.is_gqa_ ? parameters.kv_num_heads_ : parameters.num_heads_;
   if (present_key == nullptr) {
-    TensorShapeVector present_kv_shape({parameters.batch_size_, parameters.kv_num_heads_,
+    TensorShapeVector present_kv_shape({parameters.batch_size_, present_kv_heads,
                                         parameters.total_sequence_length_, parameters.head_size_});
     internal_present_key = context.CreateGPUTensor(Q->DataType(), TensorShape(present_kv_shape));
     present_key = &internal_present_key;
   }
   if (present_value == nullptr) {
-    TensorShapeVector present_kv_shape({parameters.batch_size_, parameters.kv_num_heads_,
+    TensorShapeVector present_kv_shape({parameters.batch_size_, present_kv_heads,
                                         parameters.total_sequence_length_, parameters.head_size_});
     internal_present_value = context.CreateGPUTensor(Q->DataType(), TensorShape(present_kv_shape));
     present_value = &internal_present_value;
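
The regression described in the commit message can be illustrated outside of ONNX Runtime. The sketch below is only an illustration: `AttentionParams`, `PresentKvHeads`, and `ElementCount` are hypothetical stand-ins that mirror the field names in the diff, not the real parameter struct or any ONNX Runtime API. It assumes, as the commit message states, that `kv_num_heads_` is 0 for MHA.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the attention parameters referenced in the diff.
// Only the fields needed to compute the present KV shape are modeled.
struct AttentionParams {
  bool is_gqa_;
  int batch_size_;
  int num_heads_;
  int kv_num_heads_;  // assumed to be 0 for MHA, per the commit message
  int total_sequence_length_;
  int head_size_;
};

// Mirrors the selection added by the commit: GQA uses kv_num_heads_,
// every other path falls back to num_heads_.
int PresentKvHeads(const AttentionParams& p) {
  return p.is_gqa_ ? p.kv_num_heads_ : p.num_heads_;
}

// Number of elements in a {batch, heads, total_sequence_length, head_size} buffer.
int64_t ElementCount(const AttentionParams& p, int heads) {
  assert(heads >= 0);
  return static_cast<int64_t>(p.batch_size_) * heads *
         p.total_sequence_length_ * p.head_size_;
}
```

With an MHA configuration such as `{false, 2, 12, 0, 128, 64}`, `ElementCount(p, p.kv_num_heads_)` is 0, which corresponds to the zero-sized buffer the commit fixes, while `ElementCount(p, PresentKvHeads(p))` yields the full `2 * 12 * 128 * 64` elements.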