Add non-flash SDPA path gated by ET_USE_UNFUSED_SDPA #18717
kimishpatel wants to merge 1 commit into gh/kimishpatel/237/base from
Conversation
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18717

Note: Links to docs will display an error until the docs builds have been completed.

❌ 126 New Failures, 1 Unrelated Failure as of commit 867ba3e with merge base fb1618e. One additional job failed, but this was likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Submitted by accident, not meant to land immediately.
Stack from ghstack (oldest at bottom):
Benchmarks show ET's flash attention is 10-12% slower than non-tiled
GEMM-based SDPA for decode (SeqLen=1) due to within-head tiling overhead
(multiple small GEMMs plus online softmax rescaling). This adds an
alternative non-flash code path that computes the full Q@K^T, a standard
softmax, then scores@V using two GEMM calls, gated by
`#ifdef ET_USE_UNFUSED_SDPA` so it can be tested without disrupting the
existing flash path.
The new `cpu_sdpa` function reuses the existing SeqDim, stride extraction,
`cpublas::gemm`, and `parallel_for` infrastructure. Float-only (no quantized
input support). Threading granularity is one (batch, head) per work unit.
Differential Revision: [D99677685](https://our.internmc.facebook.com/intern/diff/D99677685/)