[Executorch] Add non-flash SDPA for decode #18648
kimishpatel wants to merge 2 commits into gh/kimishpatel/221/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18648
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Cancelled Jobs as of commit beb7b11 with merge base fb1618e.
NEW FAILURE - the following job has failed.
CANCELLED JOBS - the following jobs were cancelled; please retry.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
digantdesai left a comment:
Review automatically exported from Phabricator review in Meta.
Stack from ghstack (oldest at bottom):
Add cpu_sdpa template function in op_sdpa_impl.h that provides a
simpler SDPA implementation using standard GEMM (no tiling). This is
useful as a baseline and for cases where flash attention is not optimal.
The implementation uses a single SeqDim parameter for all tensors and
supports causal masking, attention masks, GQA, and multi-threading.
During decode (seq_len == 1), the tiled flash attention implementation
has unnecessary overhead from its blocking/tiling logic. The simpler
unfused SDPA path using direct GEMM is more efficient for single-query
attention, yielding a ~25-30% decode throughput improvement on S25
(41 -> 53 tok/s for a 1.4B-parameter model).
This makes cpu_sdpa always available (previously gated behind
ET_USE_UNFUSED_SDPA) and dispatches to it when seq_len == 1 and
inputs are not quantized. Prefill continues to use flash attention.
Differential Revision: [D96044318](https://our.internmc.facebook.com/intern/diff/D96044318/)
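To make the decode-path idea concrete, here is a minimal sketch of unfused single-query SDPA (seq_len == 1) using plain GEMV-style loops: scores = q·Kᵀ, softmax, then scores·V, with a simple GQA head mapping. All names (`single_query_sdpa`, the layout assumptions) are illustrative, not the actual `cpu_sdpa` implementation; causal masking is a no-op here because the single query is the last token, and threading/quantization are omitted.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical unfused SDPA for decode (seq_len == 1), float32 only.
// Layouts (assumed, row-major):
//   q:   [num_q_heads, head_dim]           one query token
//   k,v: [num_kv_heads, kv_len, head_dim]  KV cache
//   out: [num_q_heads, head_dim]
void single_query_sdpa(
    const float* q, const float* k, const float* v, float* out,
    std::size_t num_q_heads, std::size_t num_kv_heads,
    std::size_t kv_len, std::size_t head_dim) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  // GQA: each group of query heads shares one KV head.
  const std::size_t group = num_q_heads / num_kv_heads;
  std::vector<float> scores(kv_len);
  for (std::size_t h = 0; h < num_q_heads; ++h) {
    const std::size_t kvh = h / group;
    const float* qh = q + h * head_dim;
    // scores = q @ K^T: with seq_len == 1 this is a single GEMV, no tiling.
    float max_s = -INFINITY;
    for (std::size_t t = 0; t < kv_len; ++t) {
      const float* kt = k + (kvh * kv_len + t) * head_dim;
      float s = 0.0f;
      for (std::size_t d = 0; d < head_dim; ++d) s += qh[d] * kt[d];
      scores[t] = s * scale;
      if (scores[t] > max_s) max_s = scores[t];
    }
    // Numerically stable softmax over the full kv_len.
    // Causal masking is trivially satisfied: the query is the last position.
    float denom = 0.0f;
    for (std::size_t t = 0; t < kv_len; ++t) {
      scores[t] = std::exp(scores[t] - max_s);
      denom += scores[t];
    }
    // out = softmax(scores) @ V: the second GEMV.
    float* oh = out + h * head_dim;
    for (std::size_t d = 0; d < head_dim; ++d) oh[d] = 0.0f;
    for (std::size_t t = 0; t < kv_len; ++t) {
      const float w = scores[t] / denom;
      const float* vt = v + (kvh * kv_len + t) * head_dim;
      for (std::size_t d = 0; d < head_dim; ++d) oh[d] += w * vt[d];
    }
  }
}
```

For a single query there is no tiling or online-softmax bookkeeping to amortize, which is why a direct two-GEMV pass like this can beat a flash-attention kernel at decode time.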