
Commit adef0e2

Update on "Add quantized input support to cpu_sdpa"
cpu_sdpa (unfused SDPA) previously only supported float inputs. When the model uses quantized Q/K/V (int8 with per-channel scales and zero_points), decode fell back to cpu_flash_attention, missing the ~25-30% throughput improvement from unfused SDPA. This adds quantized support to cpu_sdpa by:

- Accepting optional quantization params (zero_points, scales for Q/K/V)
- Using _q_at_k_gemm for QK^T (handles both int8 and float)
- Using _qk_at_v_gemm for the scores x V product (handles both int8 and float)
- Applying the scaling factor separately (fused with the mask add or max reduction)
- Allocating a dequantization buffer for V when quantized

The dispatch in op_sdpa.cpp is updated to route quantized decode (seq_len==1) through cpu_sdpa instead of cpu_flash_attention.

Differential Revision: [D96044310](https://our.internmc.facebook.com/intern/diff/D96044310/)

[ghstack-poisoned]
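The math behind the change can be sketched as follows. This is a minimal NumPy illustration of per-channel int8 dequantization followed by a scaled QK^T, not the actual `_q_at_k_gemm` kernel; the function names and shapes here are hypothetical stand-ins chosen for clarity.

```python
import numpy as np

def dequantize_per_channel(x_int8, scales, zero_points):
    # Per-channel (last-dim) dequant: x_fp = (x_q - zp) * scale.
    # scales / zero_points have shape (head_dim,), broadcast over rows.
    return (x_int8.astype(np.float32) - zero_points) * scales

def quantized_qk_t(q_int8, k_int8, q_scales, q_zps, k_scales, k_zps, scale_factor):
    # Hypothetical stand-in for the quantized QK^T path: dequantize both
    # operands, matmul, then apply the softmax scaling factor afterwards
    # (the commit describes fusing this with the mask add / max reduction).
    q = dequantize_per_channel(q_int8, q_scales, q_zps)
    k = dequantize_per_channel(k_int8, k_scales, k_zps)
    return (q @ k.T) * scale_factor

# Decode-shaped example: seq_len == 1 query against 2 cached keys.
q = np.array([[2, -1]], dtype=np.int8)            # (1, head_dim)
k = np.array([[1, 3], [0, -2]], dtype=np.int8)    # (kv_len, head_dim)
ones = np.ones(2, dtype=np.float32)
zeros = np.zeros(2, dtype=np.float32)
scores = quantized_qk_t(q, k, ones, zeros, ones, zeros, 0.5)
print(scores)  # [[-0.5  1. ]]
```

With identity scales and zero zero_points this reduces to a plain float QK^T, which is why a single code path can handle both int8 and float inputs.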
2 parents 8ca4001 + 5c09eb4 commit adef0e2

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

.ci/scripts/test_lora.sh

```diff
@@ -139,7 +139,7 @@ Okay, so I need to calculate 15% of 80."
 EXPECTED_QUANT_LORA_PREFIX="
 <|im_start|>user Calculate 15% of 80?<|im_end|><|im_start|>assistant
 To calculate 15% of 80, we can multiply 80 by 15/100.
-80 * 0.15 = 12.
+80 * 15/100 = 12.
 So, 15% of 80 is 12.
 #### 12
 The answer is: 12<|im_end|>"
```
