Add ONNX Runtime GQA-style SDPA benchmark#18647
Add ONNX Runtime GQA-style SDPA benchmark#18647kimishpatel wants to merge 2 commits intogh/kimishpatel/220/basefrom
Conversation
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18647
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 3 Cancelled JobsAs of commit 8b9b1ae with merge base fb1618e ( NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
digantdesai
left a comment
There was a problem hiding this comment.
Review automatically exported from Phabricator review in Meta.
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h: - Scale baked into GEMM alpha (no separate scaling pass) - Scores buffer padded to max_seq_len columns - Causal mask: zero out future positions, softmax on valid window only - Output always in [B, S, Hq, D] format Extends validation to verify ONNX GQA output matches custom_sdpa_out reference. Adds OnnxGQABenchFixture for benchmarking both layouts. Differential Revision: [D96044317](https://our.internmc.facebook.com/intern/diff/D96044317/) [ghstack-poisoned]
Stack from ghstack (oldest at bottom):
Add run_onnx_gqa_sdpa() which faithfully ports the algorithm from
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h:
Extends validation to verify ONNX GQA output matches custom_sdpa_out
reference. Adds OnnxGQABenchFixture for benchmarking both layouts.
Differential Revision: D96044317