Add quantized input support to cpu_sdpa #18649
kimishpatel wants to merge 2 commits into gh/kimishpatel/222/base from
Conversation
cpu_sdpa (unfused SDPA) previously only supported float inputs. When the model uses quantized Q/K/V (int8 with per-channel scales and zero_points), decode fell back to cpu_flash_attention, missing the ~25-30% throughput improvement from unfused SDPA.

This adds quantized support to cpu_sdpa by:
- Accepting optional quantization params (zero_points and scales for Q/K/V)
- Using _q_at_k_gemm for Q@K^T (handles both int8 and float)
- Using _qk_at_v_gemm for scores@V (handles both int8 and float)
- Applying the scaling factor separately (fused with the mask add or the max reduction)
- Allocating a dequantization buffer for V when quantized

The dispatch in op_sdpa.cpp is updated to route quantized decode (seq_len==1) through cpu_sdpa instead of cpu_flash_attention.

Differential Revision: [D96044310](https://our.internmc.facebook.com/intern/diff/D96044310/)

[ghstack-poisoned]
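For intuition, the Q@K^T step on per-channel int8 inputs can be sketched as a plain dequantize-and-matmul reference. This is a hypothetical illustration only: the names `q_at_k_gemm` and `dequantize_row` are made up here, and the real `_q_at_k_gemm` kernel in ExecuTorch fuses this work rather than materializing dequantized rows.

```python
# Hypothetical reference for the quantized Q@K^T step described above.
# The actual _q_at_k_gemm kernel avoids materializing float rows;
# this sketch only shows the per-channel dequantization semantics.

def dequantize_row(row_int8, scale, zero_point):
    """Per-channel affine int8 -> float: x_f = (x_q - zero_point) * scale."""
    return [(x - zero_point) * scale for x in row_int8]

def q_at_k_gemm(q, k, q_scales, q_zps, k_scales, k_zps):
    """Q @ K^T on int8 inputs with per-row (per-channel) quant params.

    q: [num_q, head_dim] int8, k: [num_k, head_dim] int8.
    Returns float scores [num_q, num_k]. Note the 1/sqrt(head_dim)
    scaling is NOT applied here; per the PR it is applied separately,
    fused with the mask add or the max reduction.
    """
    scores = []
    for i, q_row in enumerate(q):
        qf = dequantize_row(q_row, q_scales[i], q_zps[i])
        scores.append([
            sum(a * b for a, b in
                zip(qf, dequantize_row(k_row, k_scales[j], k_zps[j])))
            for j, k_row in enumerate(k)
        ])
    return scores

# Decode case (seq_len == 1): a single query row against the cached keys.
scores = q_at_k_gemm(
    q=[[2, 4]], k=[[2, 2]],
    q_scales=[0.5], q_zps=[0],
    k_scales=[1.0], k_zps=[1],
)
# q dequantizes to [1.0, 2.0], k to [1.0, 1.0], so scores == [[3.0]]
```

The same shape of loop applies to the scores@V product (`_qk_at_v_gemm`), where V is the quantized operand and the PR allocates a dequantization buffer for it.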
🔗 Helpful Links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18649
Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 3 Cancelled Jobs as of commit 01ab6c4 with merge base fb1618e.

NEW FAILURES - The following jobs have failed:
CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
digantdesai left a comment:
Review automatically exported from Phabricator review in Meta.