
Commit 412eeee

Update on "[Executorch] Add non-flash SDPA for decode"
Add a `cpu_sdpa` template function in `op_sdpa_impl.h` that provides a simpler SDPA implementation using standard GEMM (no tiling). This is useful as a baseline and for cases where flash attention is not optimal. The implementation uses a single `SeqDim` parameter for all tensors and supports causal masking, attention masks, GQA, and multi-threading.

During decode (seq_len == 1), the tiled flash attention implementation has unnecessary overhead from its blocking/tiling logic. The simpler unfused SDPA path using direct GEMM is more efficient for single-query attention, yielding ~25-30% decode throughput improvement on S25 (41 -> 53 tok/s for a 1.4B parameter model).

This makes `cpu_sdpa` always available (previously gated behind `ET_USE_UNFUSED_SDPA`) and dispatches to it when seq_len == 1 and inputs are not quantized. Prefill continues to use flash attention.

Differential Revision: [D96044318](https://our.internmc.facebook.com/intern/diff/D96044318/)

[ghstack-poisoned]
2 parents 6c887f8 + feae4cb commit 412eeee
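As a rough illustration of the decode path described in the commit message, the sketch below shows what an unfused single-query SDPA looks like: one GEMV over the K cache for scores, a softmax, then a weighted sum over V, with no blocking or tiling. The function name, signature, and row-major `[kv_len, head_dim]` layout are illustrative assumptions for this sketch, not the actual `cpu_sdpa` API in `op_sdpa_impl.h`. (Causal masking is a no-op here because the single query position may attend to every cached position.)

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical unfused SDPA for decode (seq_len == 1), one head.
// q: [head_dim], k/v: [kv_len * head_dim] row-major, out: [head_dim].
void single_query_sdpa(const float* q, const float* k, const float* v,
                       float* out, std::size_t kv_len, std::size_t head_dim) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  std::vector<float> scores(kv_len);

  // Scores = scale * (K q): a single GEMV pass over the K cache, no tiling.
  float max_score = -INFINITY;
  for (std::size_t t = 0; t < kv_len; ++t) {
    float dot = 0.0f;
    for (std::size_t d = 0; d < head_dim; ++d) {
      dot += q[d] * k[t * head_dim + d];
    }
    scores[t] = dot * scale;
    max_score = std::max(max_score, scores[t]);
  }

  // Numerically stable softmax over the kv_len scores.
  float denom = 0.0f;
  for (std::size_t t = 0; t < kv_len; ++t) {
    scores[t] = std::exp(scores[t] - max_score);
    denom += scores[t];
  }

  // out = P V: weighted sum of the cached value rows.
  for (std::size_t d = 0; d < head_dim; ++d) {
    out[d] = 0.0f;
  }
  for (std::size_t t = 0; t < kv_len; ++t) {
    const float p = scores[t] / denom;
    for (std::size_t d = 0; d < head_dim; ++d) {
      out[d] += p * v[t * head_dim + d];
    }
  }
}
```

For a single query this does the same arithmetic as a flash-attention kernel but skips the blocking bookkeeping, which is the overhead the commit removes from the decode path; GQA would add only a query-head-to-kv-head index mapping on top of this.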

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

.ci/scripts/test_lora.sh

@@ -139,7 +139,7 @@ Okay, so I need to calculate 15% of 80."
 EXPECTED_QUANT_LORA_PREFIX="
 <|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
 To calculate 15% of 80, we can multiply 80 by 15/100.
-80 * 0.15 = 12.
+80 * 15/100 = 12.
 So, 15% of 80 is 12.
 #### 12
 The answer is: 12<|im_end|>"

0 commit comments
