Add ONNX Runtime GQA-style SDPA benchmark #18716

kimishpatel wants to merge 1 commit into gh/kimishpatel/236/base from
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18716

Note: Links to docs will display an error until the docs builds have been completed.

❌ 127 New Failures as of commit 41e6c73 with merge base fb1618e.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a

submitted by accident, not meant to land immediately
Stack from ghstack (oldest at bottom):
Ports the attention algorithm from onnxruntime's gqa_attention_base.h
to enable direct performance comparison against ET's flash attention
and the existing standard SDPA benchmark.
Key differences from standard SDPA:

- scale baked into the GEMM alpha (saves a separate scaling pass)
- scores buffer padded to max_seq_len columns (matching ONNX's present_buffer_sequence_length)
- narrow softmax over the valid causal window only (zeros elsewhere, skips exp on masked positions)
- output in [B, S, Hq, D] with stride Hq*D, matching ONNX's interleaved output format
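As a rough illustration, the steps above can be sketched in plain Python for a single query head (the `gqa_sdpa` helper and its list-of-lists shapes are hypothetical simplifications of this PR's C++ code; the batched [B, S, Hq, D] output stride is omitted since only one head is shown):

```python
import math

def gqa_sdpa(q, k, v, scale, max_seq_len):
    """Hypothetical single-head sketch of the GQA-style SDPA steps.

    q: [S][D] queries; k, v: [S][D] keys/values (KV head already
    mapped to this query head); scale: softmax scale; max_seq_len:
    padded width of the scores buffer.
    """
    S, D = len(q), len(q[0])
    out = []
    for i in range(S):
        # Scores buffer padded to max_seq_len columns, zero-filled
        # (mirrors ONNX's present_buffer_sequence_length padding).
        scores = [0.0] * max_seq_len
        # Scale folded into the Q*K^T GEMM (alpha = scale), so no
        # separate scaling pass over the scores is needed.
        for j in range(i + 1):  # causal: keys 0..i only
            scores[j] = scale * sum(q[i][d] * k[j][d] for d in range(D))
        # Narrow softmax over the valid causal window [0, i]; masked
        # positions stay 0.0 and never go through exp().
        m = max(scores[: i + 1])
        exps = [math.exp(s - m) for s in scores[: i + 1]]
        denom = sum(exps)
        probs = [e / denom for e in exps]
        # Weighted sum of values -> one output row of length D.
        out.append([sum(probs[j] * v[j][d] for j in range(i + 1))
                    for d in range(D)])
    return out
```

With identity-like queries/keys, row 0 attends only to position 0 (causal window of size 1), while later rows blend earlier values according to the scaled scores.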
Validation tests confirm ONNX GQA matches ET custom_sdpa_out within
float32 tolerance.
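The equivalence this validation rests on (a softmax restricted to the valid causal window agrees with a full-width softmax using -inf masking) can be checked with a small sketch; both helper names here are hypothetical, not functions from this PR:

```python
import math

def narrow_softmax(scores, valid):
    """Softmax over scores[:valid] only; the masked tail stays 0.0
    and never goes through exp()."""
    m = max(scores[:valid])
    exps = [math.exp(s - m) for s in scores[:valid]]
    denom = sum(exps)
    return [e / denom for e in exps] + [0.0] * (len(scores) - valid)

def masked_softmax(scores, valid):
    """Reference: full-width softmax with -inf on masked positions,
    relying on exp(-inf) == 0.0."""
    masked = [s if j < valid else float("-inf")
              for j, s in enumerate(scores)]
    m = max(masked)
    exps = [math.exp(s - m) for s in masked]
    denom = sum(exps)
    return [e / denom for e in exps]
```

The narrow variant produces the same probabilities while skipping the exp calls on masked positions, which is where the benchmark saves work in the padded scores buffer.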
Authored with Claude.
Differential Revision: [D99677677](https://our.internmc.facebook.com/intern/diff/D99677677/)