Skip to content

Commit 47c7c5e

Browse files
committed
Update base for Update on "Add ONNX Runtime GQA-style SDPA benchmark"
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
1 parent ff4a516 commit 47c7c5e

0 file changed

File tree

    0 commit comments

    Comments
     (0)