Skip to content

Commit 3d4ca38

Browse files
committed
Add quantized input support to cpu_sdpa
Pull Request resolved: #18649

cpu_sdpa (unfused SDPA) previously only supported float inputs. When the model uses quantized Q/K/V (int8 with per-channel scales and zero_points), decode fell back to cpu_flash_attention, missing the ~25-30% throughput improvement from unfused SDPA.

This adds quantized support to cpu_sdpa by:
- Accepting optional quantization params (zero_points, scales for Q/K/V)
- Using _q_at_k_gemm for Q@K^T (handles both int8 and float)
- Using _qk_at_v_gemm for scores@V (handles both int8 and float)
- Applying scaling factor separately (fused with mask add or max reduction)
- Allocating a dequantization buffer for V when quantized

The dispatch in op_sdpa.cpp is updated to route quantized decode (seq_len==1) through cpu_sdpa instead of cpu_flash_attention.

ghstack-source-id: 374666322
@exported-using-ghexport

Differential Revision: [D96044310](https://our.internmc.facebook.com/intern/diff/D96044310/)
1 parent f3013bf commit 3d4ca38

3 files changed

Lines changed: 425 additions & 42 deletions

File tree

extension/llm/custom_ops/op_custom_sdpa_test.cpp

Lines changed: 267 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
*/
88

99
// Tests for the unfused SDPA code path (cpu_sdpa) dispatched when
10-
// seq_len == 1 and inputs are non-quantized (the decode fast-path).
11-
// These call custom_sdpa_out directly, not through sdpa_with_kv_cache.
10+
// seq_len == 1 (the decode fast-path). Covers both float and quantized
11+
// inputs. These call custom_sdpa_out / custom_quantized_sdpa_out
12+
// directly, not through sdpa_with_kv_cache.
1213

1314
#include <algorithm>
1415
#include <cmath>
@@ -117,6 +118,73 @@ void compute_reference_sdpa(
117118
}
118119
}
119120

121+
/**
 * Dequantize an int8 tensor stored contiguously in [B, S, H, D] order.
 *
 * Scales and zero_points are laid out as [B, S, H, 1]: there is exactly one
 * (scale, zero_point) pair per (b, s, h) row, shared by the D elements of
 * that row.
 *
 *   dequant(x) = (x - zero_point) * scale
 *
 * @param data   B*S*H*D quantized values.
 * @param B      extent of the outermost dimension.
 * @param S      extent of the second dimension.
 * @param H      extent of the third dimension.
 * @param D      extent of the innermost dimension.
 * @param scales B*S*H per-row scales.
 * @param zps    B*S*H per-row zero points.
 * @param out    destination for B*S*H*D floats; left untouched when any
 *               extent is non-positive.
 */
void dequantize_per_token(
    const int8_t* data,
    int B,
    int S,
    int H,
    int D,
    const float* scales,
    const int8_t* zps,
    float* out) {
  // Nothing to write for an empty tensor. Checking each extent separately
  // also keeps two negative extents from multiplying into a positive count.
  if (B <= 0 || S <= 0 || H <= 0 || D <= 0) {
    return;
  }

  // Flat indices are computed in 64 bits so large B*S*H*D products cannot
  // overflow int. A row is one (b, s, h) triple; its flat index is
  // b*S*H + s*H + h, which is also its index into scales/zps.
  const int64_t num_rows =
      static_cast<int64_t>(B) * static_cast<int64_t>(S) * static_cast<int64_t>(H);
  const int64_t row_len = static_cast<int64_t>(D);

  for (int64_t row = 0; row < num_rows; ++row) {
    const float sc = scales[row];
    const float zp = static_cast<float>(zps[row]);
    const int64_t base = row * row_len;
    for (int64_t d = 0; d < row_len; ++d) {
      out[base + d] = (static_cast<float>(data[base + d]) - zp) * sc;
    }
  }
}
149+
150+
// Helper: call custom_quantized_sdpa_out. Inputs use [B, S, H, D] layout.
151+
executorch::aten::Tensor call_custom_quantized_sdpa(
152+
const executorch::aten::Tensor& q,
153+
const executorch::aten::Tensor& k,
154+
const executorch::aten::Tensor& v,
155+
int64_t start_pos,
156+
const std::optional<executorch::aten::Tensor>& attn_mask,
157+
double dropout_p,
158+
bool is_causal,
159+
std::optional<double> scale,
160+
const std::optional<executorch::aten::Tensor>& q_zp,
161+
const std::optional<executorch::aten::Tensor>& q_sc,
162+
const std::optional<executorch::aten::Tensor>& k_zp,
163+
const std::optional<executorch::aten::Tensor>& k_sc,
164+
const std::optional<executorch::aten::Tensor>& v_zp,
165+
const std::optional<executorch::aten::Tensor>& v_sc,
166+
executorch::aten::Tensor& out) {
167+
executorch::runtime::KernelRuntimeContext ctx{};
168+
return torch::executor::native::custom_quantized_sdpa_out(
169+
ctx,
170+
q,
171+
k,
172+
v,
173+
start_pos,
174+
attn_mask,
175+
dropout_p,
176+
is_causal,
177+
scale,
178+
q_zp,
179+
q_sc,
180+
k_zp,
181+
k_sc,
182+
v_zp,
183+
v_sc,
184+
/*is_seq_at_dim_1=*/false,
185+
out);
186+
}
187+
120188
} // namespace
121189

122190
// With a single KV entry (start_pos=0), output must equal V[0].
@@ -290,3 +358,200 @@ TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) {
290358

291359
EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6);
292360
}
361+
362+
// Quantized decode: int8 Q/K/V with per-token scales and zero_points,
363+
// verified against dequantize-then-float-SDPA reference.
364+
TEST(OpCustomSdpaTest, DecodeQuantized) {
365+
TensorFactory<executorch::aten::ScalarType::Char> tfChar;
366+
TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
367+
368+
// Q: [B=1, S=1, H=2, D=4] as int8
369+
auto q = tfChar.make({1, 1, 2, 4}, {10, 20, -5, 15, -10, 5, 25, -20});
370+
371+
// K: [B=1, kv_len=3, H=2, D=4] as int8
372+
auto k = tfChar.make(
373+
{1, 3, 2, 4}, {8, -12, 18, 5, -3, 22, -8, 14, 15, 7, -20, 10,
374+
12, -15, 9, 6, -5, 25, 3, -10, 20, 8, -12, 17});
375+
376+
// V: [B=1, kv_len=3, H=2, D=4] as int8
377+
auto v = tfChar.make(
378+
{1, 3, 2, 4}, {5, 15, -8, 20, 10, -5, 18, 12, -12, 8, 22, -3,
379+
7, 20, -10, 15, 18, -5, 10, 3, -8, 12, 5, -20});
380+
381+
// Per-token scales [B, S/kv, H, 1] and zero_points [B, S/kv, H, 1]
382+
auto q_sc = tfFloat.make({1, 1, 2, 1}, {0.05f, 0.05f});
383+
auto k_sc =
384+
tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
385+
auto v_sc =
386+
tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
387+
auto q_zp = tfChar.make({1, 1, 2, 1}, {0, 0});
388+
auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
389+
auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
390+
391+
int64_t start_pos = 2;
392+
int num_valid = 3;
393+
394+
// Dequantize and compute float reference
395+
std::vector<float> q_deq(8), k_deq(24), v_deq(24);
396+
dequantize_per_token(
397+
q.const_data_ptr<int8_t>(),
398+
1,
399+
1,
400+
2,
401+
4,
402+
q_sc.const_data_ptr<float>(),
403+
q_zp.const_data_ptr<int8_t>(),
404+
q_deq.data());
405+
dequantize_per_token(
406+
k.const_data_ptr<int8_t>(),
407+
1,
408+
3,
409+
2,
410+
4,
411+
k_sc.const_data_ptr<float>(),
412+
k_zp.const_data_ptr<int8_t>(),
413+
k_deq.data());
414+
dequantize_per_token(
415+
v.const_data_ptr<int8_t>(),
416+
1,
417+
3,
418+
2,
419+
4,
420+
v_sc.const_data_ptr<float>(),
421+
v_zp.const_data_ptr<int8_t>(),
422+
v_deq.data());
423+
424+
std::vector<float> ref(8, 0.0f);
425+
compute_reference_sdpa(
426+
q_deq.data(),
427+
1,
428+
1,
429+
2,
430+
4,
431+
k_deq.data(),
432+
3,
433+
2,
434+
v_deq.data(),
435+
ref.data(),
436+
false,
437+
start_pos,
438+
num_valid);
439+
440+
auto expected = tfFloat.make({1, 1, 2, 4}, ref);
441+
auto out = tfFloat.zeros({1, 1, 2, 4});
442+
call_custom_quantized_sdpa(
443+
q,
444+
k,
445+
v,
446+
start_pos,
447+
{},
448+
0.0,
449+
false,
450+
{},
451+
q_zp,
452+
q_sc,
453+
k_zp,
454+
k_sc,
455+
v_zp,
456+
v_sc,
457+
out);
458+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3);
459+
}
460+
461+
// Quantized GQA decode: 4 query heads sharing 2 KV heads, int8 inputs.
462+
TEST(OpCustomSdpaTest, DecodeQuantizedGQA) {
463+
TensorFactory<executorch::aten::ScalarType::Char> tfChar;
464+
TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
465+
466+
// Q: [B=1, S=1, H_q=4, D=4] as int8
467+
auto q = tfChar.make(
468+
{1, 1, 4, 4},
469+
{10, 20, -5, 15, -10, 5, 25, -20, 8, -3, 12, 7, -15, 18, 4, -8});
470+
471+
// K: [B=1, kv_len=3, H_kv=2, D=4] as int8
472+
auto k = tfChar.make(
473+
{1, 3, 2, 4}, {8, -12, 18, 5, -3, 22, -8, 14, 15, 7, -20, 10,
474+
12, -15, 9, 6, -5, 25, 3, -10, 20, 8, -12, 17});
475+
476+
// V: [B=1, kv_len=3, H_kv=2, D=4] as int8
477+
auto v = tfChar.make(
478+
{1, 3, 2, 4}, {5, 15, -8, 20, 10, -5, 18, 12, -12, 8, 22, -3,
479+
7, 20, -10, 15, 18, -5, 10, 3, -8, 12, 5, -20});
480+
481+
auto q_sc = tfFloat.make({1, 1, 4, 1}, {0.05f, 0.05f, 0.05f, 0.05f});
482+
auto k_sc =
483+
tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
484+
auto v_sc =
485+
tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
486+
auto q_zp = tfChar.make({1, 1, 4, 1}, {0, 0, 0, 0});
487+
auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
488+
auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
489+
490+
int64_t start_pos = 2;
491+
int num_valid = 3;
492+
493+
std::vector<float> q_deq(16), k_deq(24), v_deq(24);
494+
dequantize_per_token(
495+
q.const_data_ptr<int8_t>(),
496+
1,
497+
1,
498+
4,
499+
4,
500+
q_sc.const_data_ptr<float>(),
501+
q_zp.const_data_ptr<int8_t>(),
502+
q_deq.data());
503+
dequantize_per_token(
504+
k.const_data_ptr<int8_t>(),
505+
1,
506+
3,
507+
2,
508+
4,
509+
k_sc.const_data_ptr<float>(),
510+
k_zp.const_data_ptr<int8_t>(),
511+
k_deq.data());
512+
dequantize_per_token(
513+
v.const_data_ptr<int8_t>(),
514+
1,
515+
3,
516+
2,
517+
4,
518+
v_sc.const_data_ptr<float>(),
519+
v_zp.const_data_ptr<int8_t>(),
520+
v_deq.data());
521+
522+
std::vector<float> ref(16, 0.0f);
523+
compute_reference_sdpa(
524+
q_deq.data(),
525+
1,
526+
1,
527+
4,
528+
4,
529+
k_deq.data(),
530+
3,
531+
2,
532+
v_deq.data(),
533+
ref.data(),
534+
false,
535+
start_pos,
536+
num_valid);
537+
538+
auto expected = tfFloat.make({1, 1, 4, 4}, ref);
539+
auto out = tfFloat.zeros({1, 1, 4, 4});
540+
call_custom_quantized_sdpa(
541+
q,
542+
k,
543+
v,
544+
start_pos,
545+
{},
546+
0.0,
547+
false,
548+
{},
549+
q_zp,
550+
q_sc,
551+
k_zp,
552+
k_sc,
553+
v_zp,
554+
v_sc,
555+
out);
556+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3);
557+
}

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ Tensor& custom_sdpa_out_impl(
412412
InvalidArgument,
413413
output);
414414

415-
bool use_unfused_sdpa = q.scalar_type() != ScalarType::Char && seq_len == 1;
415+
bool use_unfused_sdpa = seq_len == 1;
416416
if (use_unfused_sdpa) {
417417
ET_SWITCH_FLOAT_TYPES(output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
418418
sdpa::impl::cpu_sdpa<CTYPE>(
@@ -426,7 +426,13 @@ Tensor& custom_sdpa_out_impl(
426426
scale,
427427
seq_dim,
428428
start_pos,
429-
num_keys_for_causal_attention);
429+
num_keys_for_causal_attention,
430+
q_zero_points,
431+
q_scales,
432+
k_zero_points,
433+
k_scales,
434+
v_zero_points,
435+
v_scales);
430436
});
431437
} else {
432438
ET_SWITCH_FLOAT_TYPES(

0 commit comments

Comments
 (0)