Commit fbe8125

Update on "Add ONNX Runtime GQA-style SDPA benchmark"
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h:

- Scale baked into GEMM alpha (no separate scaling pass)
- Scores buffer padded to max_seq_len columns
- Causal mask: zero out future positions, softmax on valid window only
- Output always in [B, S, Hq, D] format

Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts.

Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/)

[ghstack-poisoned]
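The four points in the commit message can be illustrated with a minimal pure-Python sketch. This is not the code from this commit or from gqa_attention_base.h: the function name `gqa_sdpa_sketch`, the nested-list layout, the 1/sqrt(D) scale, and the absence of a past KV cache are all assumptions made for illustration.

```python
import math

def gqa_sdpa_sketch(q, k, v, max_seq_len):
    """Hypothetical GQA-style SDPA on nested lists.

    q: [B][S][Hq][D]; k, v: [B][S][Hkv][D]. Returns [B][S][Hq][D].
    Assumes S <= max_seq_len, Hq divisible by Hkv, and no past KV cache.
    """
    B, S, Hq, D = len(q), len(q[0]), len(q[0][0]), len(q[0][0][0])
    Hkv = len(k[0][0])
    group = Hq // Hkv               # query heads sharing one KV head
    alpha = 1.0 / math.sqrt(D)      # assumed scale, applied inside the "GEMM"
    out = [[[[0.0] * D for _ in range(Hq)] for _ in range(S)] for _ in range(B)]
    for b in range(B):
        for h in range(Hq):
            kv_h = h // group
            for s in range(S):
                # Scores row is padded to max_seq_len columns.
                scores = [0.0] * max_seq_len
                for t in range(S):
                    # Scale folded into the product: no separate scaling pass.
                    scores[t] = alpha * sum(
                        q[b][s][h][d] * k[b][t][kv_h][d] for d in range(D))
                valid = s + 1       # causal window: positions 0..s
                # Zero out future positions (and padding).
                for t in range(valid, max_seq_len):
                    scores[t] = 0.0
                # Softmax over the valid window only.
                m = max(scores[:valid])
                exps = [math.exp(x - m) for x in scores[:valid]]
                denom = sum(exps)
                for t in range(valid):
                    scores[t] = exps[t] / denom
                # Weighted sum of values, written in [B, S, Hq, D] order.
                for d in range(D):
                    out[b][s][h][d] = sum(
                        scores[t] * v[b][t][kv_h][d] for t in range(valid))
    return out
```

Because future positions are zeroed after the score computation and the softmax only reads the first `s + 1` columns, the padded columns never affect the output; the padding only fixes the row stride of the scores buffer.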
2 parents 8b9b1ae + 4f9bb4d commit fbe8125

1 file changed

Lines changed: 247 additions & 93 deletions
