Commit 47c7c5e
Update base for Update on "Add ONNX Runtime GQA-style SDPA benchmark"
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h:
- Scale baked into GEMM alpha (no separate scaling pass)
- Scores buffer padded to max_seq_len columns
- Causal mask: zero out future positions, softmax on valid window only
- Output always in [B, S, Hq, D] format
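
The four points above can be illustrated with a minimal pure-Python sketch. This is not the ONNX Runtime code or the benchmark's `run_onnx_gqa_sdpa()`. The function name `gqa_sdpa_sketch`, the nested-list inputs, the assumed `[B, T, Hkv, D]` layout for K and V, and the `h // group` mapping from query heads to KV heads are all assumptions made for illustration.

```python
import math

def gqa_sdpa_sketch(q, k, v, max_seq_len):
    """Hypothetical sketch. q: [B][S][Hq][D]; k, v: [B][T][Hkv][D] (assumed layout).

    Returns the attention output as nested lists in [B][S][Hq][D] order.
    """
    B, S, Hq, D = len(q), len(q[0]), len(q[0][0]), len(q[0][0][0])
    T, Hkv = len(k[0]), len(k[0][0])
    assert Hq % Hkv == 0 and S <= T <= max_seq_len
    group = Hq // Hkv            # query heads sharing one KV head (assumed mapping)
    alpha = 1.0 / math.sqrt(D)   # scale applied inside the QK^T product
    out = [[[[0.0] * D for _ in range(Hq)] for _ in range(S)] for _ in range(B)]
    for b in range(B):
        for h in range(Hq):
            kv_h = h // group
            # Scores buffer: S rows, each padded to max_seq_len columns.
            scores = [0.0] * (S * max_seq_len)
            for i in range(S):
                row = i * max_seq_len
                for j in range(T):
                    dot = sum(q[b][i][h][d] * k[b][j][kv_h][d] for d in range(D))
                    scores[row + j] = alpha * dot  # no separate scaling pass
            for i in range(S):
                row = i * max_seq_len
                valid = T - S + i + 1  # causal window for query position i
                m = max(scores[row:row + valid])
                exps = [math.exp(scores[row + j] - m) for j in range(valid)]
                total = sum(exps)
                for j in range(valid):
                    scores[row + j] = exps[j] / total  # softmax on valid window only
                for j in range(valid, T):
                    scores[row + j] = 0.0              # zero out future positions
            for i in range(S):
                row = i * max_seq_len
                for d in range(D):
                    out[b][i][h][d] = sum(
                        scores[row + j] * v[b][j][kv_h][d] for j in range(T)
                    )
    return out
```

With all-zero queries the softmax is uniform over each causal window, so the first query row returns `v[0]` and the second returns the mean of `v[0]` and `v[1]`, which makes the masking easy to check by hand.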
Extend validation to verify that the ONNX GQA output matches the
custom_sdpa_out reference. Add OnnxGQABenchFixture for benchmarking both layouts.
Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/)
[ghstack-poisoned]
1 parent ff4a516 commit 47c7c5e
0 files changed