Add GEMM-based standard SDPA benchmark and fix custom_sdpa_out signature#18715
Add GEMM-based standard SDPA benchmark and fix custom_sdpa_out signature#18715kimishpatel wants to merge 1 commit intogh/kimishpatel/235/basefrom
Conversation
Adds a standalone GEMM-based (non-tiled) attention benchmark alongside the existing ET flash attention benchmark, allowing direct comparison of the two algorithms. The standard SDPA uses cblas_sgemm for Q@K^T and scores@V with 3-pass softmax, matching the approach used by ONNX Runtime's GQA operator. Both transposed [B,H,S,D] and standard [B,S,H,D] cache layouts are supported via BLAS leading dimension parameter. Validation tests run before benchmarks to ensure correctness against ET's custom_sdpa_out. Also fixes broken custom_sdpa_out calls to match the new 3-bool signature (is_seq_dim_2, is_k_seq_dim_2, is_v_seq_dim_2). Authored with Claude. Differential Revision: [D99677686](https://our.internmc.facebook.com/intern/diff/D99677686/) [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18715
Note: Links to docs will display an error until the docs builds have been completed. ❌ 128 New FailuresAs of commit e968c4a with merge base fb1618e ( NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
submitted by accident, not meant to land immedidately |
Stack from ghstack (oldest at bottom):
Adds a standalone GEMM-based (non-tiled) attention benchmark alongside
the existing ET flash attention benchmark, allowing direct comparison
of the two algorithms. The standard SDPA uses cblas_sgemm for Q@K^T
and scores@V with 3-pass softmax, matching the approach used by
ONNX Runtime's GQA operator.
Both transposed [B,H,S,D] and standard [B,S,H,D] cache layouts are
supported via BLAS leading dimension parameter. Validation tests run
before benchmarks to ensure correctness against ET's custom_sdpa_out.
Also fixes broken custom_sdpa_out calls to match the new 3-bool
signature (is_seq_dim_2, is_k_seq_dim_2, is_v_seq_dim_2).
Authored with Claude.
Differential Revision: D99677686