Skip to content

Commit f541c79

Browse files
committed
Add quantized input support to cpu_sdpa
cpu_sdpa (unfused SDPA) previously only supported float inputs. When the model uses quantized Q/K/V (int8 with per-channel scales and zero_points), decode fell back to cpu_flash_attention, missing the ~25-30% throughput improvement from unfused SDPA.

This adds quantized support to cpu_sdpa by:
- Accepting optional quantization params (zero_points, scales for Q/K/V)
- Using _q_at_k_gemm for QK^T (handles both int8 and float)
- Using _qk_at_v_gemm for scoresV (handles both int8 and float)
- Applying scaling factor separately (fused with mask add or max reduction)
- Allocating a dequantization buffer for V when quantized

The dispatch in op_sdpa.cpp is updated to route quantized decode (seq_len==1) through cpu_sdpa instead of cpu_flash_attention.

Differential Revision: [D96044310](https://our.internmc.facebook.com/intern/diff/D96044310/)

ghstack-source-id: 361224787

Pull Request resolved: #18649
1 parent 66b5c73 commit f541c79

File tree

3 files changed

+328
-43
lines changed

3 files changed

+328
-43
lines changed

extension/llm/custom_ops/op_custom_sdpa_test.cpp

Lines changed: 186 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
*/
88

99
// Tests for the unfused SDPA code path (cpu_sdpa) dispatched when
10-
// seq_len == 1 and inputs are non-quantized (the decode fast-path).
11-
// These call custom_sdpa_out directly, not through sdpa_with_kv_cache.
10+
// seq_len == 1 (the decode fast-path). Covers both float and quantized
11+
// inputs. These call custom_sdpa_out / custom_quantized_sdpa_out
12+
// directly, not through sdpa_with_kv_cache.
1213

1314
#include <algorithm>
1415
#include <cmath>
@@ -114,6 +115,55 @@ void compute_reference_sdpa(
114115
}
115116
}
116117

118+
/**
119+
* Dequantize int8 tensor in [B, S, H, D] layout using per-token
120+
* scales/zero_points in [B, S, H, 1] layout.
121+
* dequant(x) = (x - zero_point) * scale
122+
*/
123+
void dequantize_per_token(
124+
const int8_t* data, int B, int S, int H, int D,
125+
const float* scales,
126+
const int8_t* zps,
127+
float* out) {
128+
for (int b = 0; b < B; b++) {
129+
for (int s = 0; s < S; s++) {
130+
for (int h = 0; h < H; h++) {
131+
int param_idx = b * S * H + s * H + h;
132+
float sc = scales[param_idx];
133+
float zp = static_cast<float>(zps[param_idx]);
134+
for (int d = 0; d < D; d++) {
135+
int idx = b * S * H * D + s * H * D + h * D + d;
136+
out[idx] = (static_cast<float>(data[idx]) - zp) * sc;
137+
}
138+
}
139+
}
140+
}
141+
}
142+
143+
// Helper: call custom_quantized_sdpa_out. Inputs use [B, S, H, D] layout.
144+
executorch::aten::Tensor call_custom_quantized_sdpa(
145+
const executorch::aten::Tensor& q,
146+
const executorch::aten::Tensor& k,
147+
const executorch::aten::Tensor& v,
148+
int64_t start_pos,
149+
const std::optional<executorch::aten::Tensor>& attn_mask,
150+
double dropout_p,
151+
bool is_causal,
152+
std::optional<double> scale,
153+
const std::optional<executorch::aten::Tensor>& q_zp,
154+
const std::optional<executorch::aten::Tensor>& q_sc,
155+
const std::optional<executorch::aten::Tensor>& k_zp,
156+
const std::optional<executorch::aten::Tensor>& k_sc,
157+
const std::optional<executorch::aten::Tensor>& v_zp,
158+
const std::optional<executorch::aten::Tensor>& v_sc,
159+
executorch::aten::Tensor& out) {
160+
executorch::runtime::KernelRuntimeContext ctx{};
161+
return torch::executor::native::custom_quantized_sdpa_out(
162+
ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale,
163+
q_zp, q_sc, k_zp, k_sc, v_zp, v_sc,
164+
/*is_seq_at_dim_1=*/false, out);
165+
}
166+
117167
} // namespace
118168

119169
// With a single KV entry (start_pos=0), output must equal V[0].
@@ -263,3 +313,137 @@ TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) {
263313

264314
EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6);
265315
}
316+
317+
// Quantized decode: int8 Q/K/V with per-token scales and zero_points,
318+
// verified against dequantize-then-float-SDPA reference.
319+
TEST(OpCustomSdpaTest, DecodeQuantized) {
320+
TensorFactory<executorch::aten::ScalarType::Char> tfChar;
321+
TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
322+
323+
// Q: [B=1, S=1, H=2, D=4] as int8
324+
auto q = tfChar.make(
325+
{1, 1, 2, 4},
326+
{10, 20, -5, 15, -10, 5, 25, -20});
327+
328+
// K: [B=1, kv_len=3, H=2, D=4] as int8
329+
auto k = tfChar.make(
330+
{1, 3, 2, 4},
331+
{8, -12, 18, 5, -3, 22, -8, 14,
332+
15, 7, -20, 10, 12, -15, 9, 6,
333+
-5, 25, 3, -10, 20, 8, -12, 17});
334+
335+
// V: [B=1, kv_len=3, H=2, D=4] as int8
336+
auto v = tfChar.make(
337+
{1, 3, 2, 4},
338+
{5, 15, -8, 20, 10, -5, 18, 12,
339+
-12, 8, 22, -3, 7, 20, -10, 15,
340+
18, -5, 10, 3, -8, 12, 5, -20});
341+
342+
// Per-token scales [B, S/kv, H, 1] and zero_points [B, S/kv, H, 1]
343+
auto q_sc = tfFloat.make({1, 1, 2, 1}, {0.05f, 0.05f});
344+
auto k_sc = tfFloat.make({1, 3, 2, 1},
345+
{0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
346+
auto v_sc = tfFloat.make({1, 3, 2, 1},
347+
{0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
348+
auto q_zp = tfChar.make({1, 1, 2, 1}, {0, 0});
349+
auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
350+
auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
351+
352+
int64_t start_pos = 2;
353+
int num_valid = 3;
354+
355+
// Dequantize and compute float reference
356+
std::vector<float> q_deq(8), k_deq(24), v_deq(24);
357+
dequantize_per_token(
358+
q.const_data_ptr<int8_t>(), 1, 1, 2, 4,
359+
q_sc.const_data_ptr<float>(), q_zp.const_data_ptr<int8_t>(),
360+
q_deq.data());
361+
dequantize_per_token(
362+
k.const_data_ptr<int8_t>(), 1, 3, 2, 4,
363+
k_sc.const_data_ptr<float>(), k_zp.const_data_ptr<int8_t>(),
364+
k_deq.data());
365+
dequantize_per_token(
366+
v.const_data_ptr<int8_t>(), 1, 3, 2, 4,
367+
v_sc.const_data_ptr<float>(), v_zp.const_data_ptr<int8_t>(),
368+
v_deq.data());
369+
370+
std::vector<float> ref(8, 0.0f);
371+
compute_reference_sdpa(
372+
q_deq.data(), 1, 1, 2, 4,
373+
k_deq.data(), 3, 2,
374+
v_deq.data(),
375+
ref.data(), false, start_pos, num_valid);
376+
377+
auto expected = tfFloat.make({1, 1, 2, 4}, ref);
378+
auto out = tfFloat.zeros({1, 1, 2, 4});
379+
call_custom_quantized_sdpa(
380+
q, k, v, start_pos, {}, 0.0, false, {},
381+
q_zp, q_sc, k_zp, k_sc, v_zp, v_sc, out);
382+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3);
383+
}
384+
385+
// Quantized GQA decode: 4 query heads sharing 2 KV heads, int8 inputs.
386+
TEST(OpCustomSdpaTest, DecodeQuantizedGQA) {
387+
TensorFactory<executorch::aten::ScalarType::Char> tfChar;
388+
TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
389+
390+
// Q: [B=1, S=1, H_q=4, D=4] as int8
391+
auto q = tfChar.make(
392+
{1, 1, 4, 4},
393+
{10, 20, -5, 15, -10, 5, 25, -20,
394+
8, -3, 12, 7, -15, 18, 4, -8});
395+
396+
// K: [B=1, kv_len=3, H_kv=2, D=4] as int8
397+
auto k = tfChar.make(
398+
{1, 3, 2, 4},
399+
{8, -12, 18, 5, -3, 22, -8, 14,
400+
15, 7, -20, 10, 12, -15, 9, 6,
401+
-5, 25, 3, -10, 20, 8, -12, 17});
402+
403+
// V: [B=1, kv_len=3, H_kv=2, D=4] as int8
404+
auto v = tfChar.make(
405+
{1, 3, 2, 4},
406+
{5, 15, -8, 20, 10, -5, 18, 12,
407+
-12, 8, 22, -3, 7, 20, -10, 15,
408+
18, -5, 10, 3, -8, 12, 5, -20});
409+
410+
auto q_sc = tfFloat.make({1, 1, 4, 1}, {0.05f, 0.05f, 0.05f, 0.05f});
411+
auto k_sc = tfFloat.make({1, 3, 2, 1},
412+
{0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
413+
auto v_sc = tfFloat.make({1, 3, 2, 1},
414+
{0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
415+
auto q_zp = tfChar.make({1, 1, 4, 1}, {0, 0, 0, 0});
416+
auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
417+
auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
418+
419+
int64_t start_pos = 2;
420+
int num_valid = 3;
421+
422+
std::vector<float> q_deq(16), k_deq(24), v_deq(24);
423+
dequantize_per_token(
424+
q.const_data_ptr<int8_t>(), 1, 1, 4, 4,
425+
q_sc.const_data_ptr<float>(), q_zp.const_data_ptr<int8_t>(),
426+
q_deq.data());
427+
dequantize_per_token(
428+
k.const_data_ptr<int8_t>(), 1, 3, 2, 4,
429+
k_sc.const_data_ptr<float>(), k_zp.const_data_ptr<int8_t>(),
430+
k_deq.data());
431+
dequantize_per_token(
432+
v.const_data_ptr<int8_t>(), 1, 3, 2, 4,
433+
v_sc.const_data_ptr<float>(), v_zp.const_data_ptr<int8_t>(),
434+
v_deq.data());
435+
436+
std::vector<float> ref(16, 0.0f);
437+
compute_reference_sdpa(
438+
q_deq.data(), 1, 1, 4, 4,
439+
k_deq.data(), 3, 2,
440+
v_deq.data(),
441+
ref.data(), false, start_pos, num_valid);
442+
443+
auto expected = tfFloat.make({1, 1, 4, 4}, ref);
444+
auto out = tfFloat.zeros({1, 1, 4, 4});
445+
call_custom_quantized_sdpa(
446+
q, k, v, start_pos, {}, 0.0, false, {},
447+
q_zp, q_sc, k_zp, k_sc, v_zp, v_sc, out);
448+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3);
449+
}

extension/llm/custom_ops/op_sdpa.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -412,15 +412,17 @@ Tensor& custom_sdpa_out_impl(
412412
InvalidArgument,
413413
output);
414414

415-
bool use_unfused_sdpa = q.scalar_type() != ScalarType::Char &&
416-
seq_len == 1;
415+
bool use_unfused_sdpa = seq_len == 1;
417416
if (use_unfused_sdpa) {
418417
ET_SWITCH_FLOAT_TYPES(
419418
output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
420419
sdpa::impl::cpu_sdpa<CTYPE>(
421420
ctx, output, q, k, v, is_causal, attn_mask, scale,
422421
seq_dim,
423-
start_pos, num_keys_for_causal_attention);
422+
start_pos, num_keys_for_causal_attention,
423+
q_zero_points, q_scales,
424+
k_zero_points, k_scales,
425+
v_zero_points, v_scales);
424426
});
425427
} else {
426428
ET_SWITCH_FLOAT_TYPES(

0 commit comments

Comments
 (0)