|
7 | 7 | */ |
8 | 8 |
|
9 | 9 | // Tests for the unfused SDPA code path (cpu_sdpa) dispatched when |
10 | | -// seq_len == 1 and inputs are non-quantized (the decode fast-path). |
11 | | -// These call custom_sdpa_out directly, not through sdpa_with_kv_cache. |
| 10 | +// seq_len == 1 (the decode fast-path). Covers both float and quantized |
| 11 | +// inputs. These call custom_sdpa_out / custom_quantized_sdpa_out |
| 12 | +// directly, not through sdpa_with_kv_cache. |
12 | 13 |
|
13 | 14 | #include <algorithm> |
14 | 15 | #include <cmath> |
@@ -117,6 +118,73 @@ void compute_reference_sdpa( |
117 | 118 | } |
118 | 119 | } |
119 | 120 |
|
/**
 * Dequantize a contiguous int8 tensor in [B, S, H, D] layout using per-token
 * scales and zero points stored contiguously in [B, S, H, 1] layout.
 *
 *   dequant(x) = (x - zero_point) * scale
 *
 * @param data   Quantized input, B * S * H * D elements.
 * @param B      Extent of dim 0.
 * @param S      Extent of dim 1.
 * @param H      Extent of dim 2.
 * @param D      Extent of dim 3 (elements sharing one scale / zero point).
 * @param scales One scale per (b, s, h) row, B * S * H elements.
 * @param zps    One zero point per (b, s, h) row, B * S * H elements.
 * @param out    Destination for B * S * H * D dequantized floats.
 *
 * If any extent is not positive nothing is written.
 */
void dequantize_per_token(
    const int8_t* data,
    int B,
    int S,
    int H,
    int D,
    const float* scales,
    const int8_t* zps,
    float* out) {
  if (B <= 0 || S <= 0 || H <= 0 || D <= 0) {
    return;
  }
  // Offsets are kept in 64 bits: products such as B * S * H * D evaluated in
  // `int` overflow for tensors with more than INT_MAX elements.
  const int64_t num_rows = static_cast<int64_t>(B) * S * H;
  for (int64_t row = 0; row < num_rows; ++row) {
    // With contiguous storage, `row` equals b * S * H + s * H + h, which is
    // exactly the index of this row's scale and zero point.
    const float sc = scales[row];
    const float zp = static_cast<float>(zps[row]);
    const int64_t base = row * D;
    for (int d = 0; d < D; ++d) {
      out[base + d] = (static_cast<float>(data[base + d]) - zp) * sc;
    }
  }
}
| 149 | + |
| 150 | +// Helper: call custom_quantized_sdpa_out. Inputs use [B, S, H, D] layout. |
| 151 | +executorch::aten::Tensor call_custom_quantized_sdpa( |
| 152 | + const executorch::aten::Tensor& q, |
| 153 | + const executorch::aten::Tensor& k, |
| 154 | + const executorch::aten::Tensor& v, |
| 155 | + int64_t start_pos, |
| 156 | + const std::optional<executorch::aten::Tensor>& attn_mask, |
| 157 | + double dropout_p, |
| 158 | + bool is_causal, |
| 159 | + std::optional<double> scale, |
| 160 | + const std::optional<executorch::aten::Tensor>& q_zp, |
| 161 | + const std::optional<executorch::aten::Tensor>& q_sc, |
| 162 | + const std::optional<executorch::aten::Tensor>& k_zp, |
| 163 | + const std::optional<executorch::aten::Tensor>& k_sc, |
| 164 | + const std::optional<executorch::aten::Tensor>& v_zp, |
| 165 | + const std::optional<executorch::aten::Tensor>& v_sc, |
| 166 | + executorch::aten::Tensor& out) { |
| 167 | + executorch::runtime::KernelRuntimeContext ctx{}; |
| 168 | + return torch::executor::native::custom_quantized_sdpa_out( |
| 169 | + ctx, |
| 170 | + q, |
| 171 | + k, |
| 172 | + v, |
| 173 | + start_pos, |
| 174 | + attn_mask, |
| 175 | + dropout_p, |
| 176 | + is_causal, |
| 177 | + scale, |
| 178 | + q_zp, |
| 179 | + q_sc, |
| 180 | + k_zp, |
| 181 | + k_sc, |
| 182 | + v_zp, |
| 183 | + v_sc, |
| 184 | + /*is_seq_at_dim_1=*/false, |
| 185 | + out); |
| 186 | +} |
| 187 | + |
120 | 188 | } // namespace |
121 | 189 |
|
122 | 190 | // With a single KV entry (start_pos=0), output must equal V[0]. |
@@ -290,3 +358,200 @@ TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) { |
290 | 358 |
|
291 | 359 | EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6); |
292 | 360 | } |
| 361 | + |
| 362 | +// Quantized decode: int8 Q/K/V with per-token scales and zero_points, |
| 363 | +// verified against dequantize-then-float-SDPA reference. |
| 364 | +TEST(OpCustomSdpaTest, DecodeQuantized) { |
| 365 | + TensorFactory<executorch::aten::ScalarType::Char> tfChar; |
| 366 | + TensorFactory<executorch::aten::ScalarType::Float> tfFloat; |
| 367 | + |
| 368 | + // Q: [B=1, S=1, H=2, D=4] as int8 |
| 369 | + auto q = tfChar.make({1, 1, 2, 4}, {10, 20, -5, 15, -10, 5, 25, -20}); |
| 370 | + |
| 371 | + // K: [B=1, kv_len=3, H=2, D=4] as int8 |
| 372 | + auto k = tfChar.make( |
| 373 | + {1, 3, 2, 4}, {8, -12, 18, 5, -3, 22, -8, 14, 15, 7, -20, 10, |
| 374 | + 12, -15, 9, 6, -5, 25, 3, -10, 20, 8, -12, 17}); |
| 375 | + |
| 376 | + // V: [B=1, kv_len=3, H=2, D=4] as int8 |
| 377 | + auto v = tfChar.make( |
| 378 | + {1, 3, 2, 4}, {5, 15, -8, 20, 10, -5, 18, 12, -12, 8, 22, -3, |
| 379 | + 7, 20, -10, 15, 18, -5, 10, 3, -8, 12, 5, -20}); |
| 380 | + |
| 381 | + // Per-token scales [B, S/kv, H, 1] and zero_points [B, S/kv, H, 1] |
| 382 | + auto q_sc = tfFloat.make({1, 1, 2, 1}, {0.05f, 0.05f}); |
| 383 | + auto k_sc = |
| 384 | + tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f}); |
| 385 | + auto v_sc = |
| 386 | + tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f}); |
| 387 | + auto q_zp = tfChar.make({1, 1, 2, 1}, {0, 0}); |
| 388 | + auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0}); |
| 389 | + auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0}); |
| 390 | + |
| 391 | + int64_t start_pos = 2; |
| 392 | + int num_valid = 3; |
| 393 | + |
| 394 | + // Dequantize and compute float reference |
| 395 | + std::vector<float> q_deq(8), k_deq(24), v_deq(24); |
| 396 | + dequantize_per_token( |
| 397 | + q.const_data_ptr<int8_t>(), |
| 398 | + 1, |
| 399 | + 1, |
| 400 | + 2, |
| 401 | + 4, |
| 402 | + q_sc.const_data_ptr<float>(), |
| 403 | + q_zp.const_data_ptr<int8_t>(), |
| 404 | + q_deq.data()); |
| 405 | + dequantize_per_token( |
| 406 | + k.const_data_ptr<int8_t>(), |
| 407 | + 1, |
| 408 | + 3, |
| 409 | + 2, |
| 410 | + 4, |
| 411 | + k_sc.const_data_ptr<float>(), |
| 412 | + k_zp.const_data_ptr<int8_t>(), |
| 413 | + k_deq.data()); |
| 414 | + dequantize_per_token( |
| 415 | + v.const_data_ptr<int8_t>(), |
| 416 | + 1, |
| 417 | + 3, |
| 418 | + 2, |
| 419 | + 4, |
| 420 | + v_sc.const_data_ptr<float>(), |
| 421 | + v_zp.const_data_ptr<int8_t>(), |
| 422 | + v_deq.data()); |
| 423 | + |
| 424 | + std::vector<float> ref(8, 0.0f); |
| 425 | + compute_reference_sdpa( |
| 426 | + q_deq.data(), |
| 427 | + 1, |
| 428 | + 1, |
| 429 | + 2, |
| 430 | + 4, |
| 431 | + k_deq.data(), |
| 432 | + 3, |
| 433 | + 2, |
| 434 | + v_deq.data(), |
| 435 | + ref.data(), |
| 436 | + false, |
| 437 | + start_pos, |
| 438 | + num_valid); |
| 439 | + |
| 440 | + auto expected = tfFloat.make({1, 1, 2, 4}, ref); |
| 441 | + auto out = tfFloat.zeros({1, 1, 2, 4}); |
| 442 | + call_custom_quantized_sdpa( |
| 443 | + q, |
| 444 | + k, |
| 445 | + v, |
| 446 | + start_pos, |
| 447 | + {}, |
| 448 | + 0.0, |
| 449 | + false, |
| 450 | + {}, |
| 451 | + q_zp, |
| 452 | + q_sc, |
| 453 | + k_zp, |
| 454 | + k_sc, |
| 455 | + v_zp, |
| 456 | + v_sc, |
| 457 | + out); |
| 458 | + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3); |
| 459 | +} |
| 460 | + |
| 461 | +// Quantized GQA decode: 4 query heads sharing 2 KV heads, int8 inputs. |
| 462 | +TEST(OpCustomSdpaTest, DecodeQuantizedGQA) { |
| 463 | + TensorFactory<executorch::aten::ScalarType::Char> tfChar; |
| 464 | + TensorFactory<executorch::aten::ScalarType::Float> tfFloat; |
| 465 | + |
| 466 | + // Q: [B=1, S=1, H_q=4, D=4] as int8 |
| 467 | + auto q = tfChar.make( |
| 468 | + {1, 1, 4, 4}, |
| 469 | + {10, 20, -5, 15, -10, 5, 25, -20, 8, -3, 12, 7, -15, 18, 4, -8}); |
| 470 | + |
| 471 | + // K: [B=1, kv_len=3, H_kv=2, D=4] as int8 |
| 472 | + auto k = tfChar.make( |
| 473 | + {1, 3, 2, 4}, {8, -12, 18, 5, -3, 22, -8, 14, 15, 7, -20, 10, |
| 474 | + 12, -15, 9, 6, -5, 25, 3, -10, 20, 8, -12, 17}); |
| 475 | + |
| 476 | + // V: [B=1, kv_len=3, H_kv=2, D=4] as int8 |
| 477 | + auto v = tfChar.make( |
| 478 | + {1, 3, 2, 4}, {5, 15, -8, 20, 10, -5, 18, 12, -12, 8, 22, -3, |
| 479 | + 7, 20, -10, 15, 18, -5, 10, 3, -8, 12, 5, -20}); |
| 480 | + |
| 481 | + auto q_sc = tfFloat.make({1, 1, 4, 1}, {0.05f, 0.05f, 0.05f, 0.05f}); |
| 482 | + auto k_sc = |
| 483 | + tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f}); |
| 484 | + auto v_sc = |
| 485 | + tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f}); |
| 486 | + auto q_zp = tfChar.make({1, 1, 4, 1}, {0, 0, 0, 0}); |
| 487 | + auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0}); |
| 488 | + auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0}); |
| 489 | + |
| 490 | + int64_t start_pos = 2; |
| 491 | + int num_valid = 3; |
| 492 | + |
| 493 | + std::vector<float> q_deq(16), k_deq(24), v_deq(24); |
| 494 | + dequantize_per_token( |
| 495 | + q.const_data_ptr<int8_t>(), |
| 496 | + 1, |
| 497 | + 1, |
| 498 | + 4, |
| 499 | + 4, |
| 500 | + q_sc.const_data_ptr<float>(), |
| 501 | + q_zp.const_data_ptr<int8_t>(), |
| 502 | + q_deq.data()); |
| 503 | + dequantize_per_token( |
| 504 | + k.const_data_ptr<int8_t>(), |
| 505 | + 1, |
| 506 | + 3, |
| 507 | + 2, |
| 508 | + 4, |
| 509 | + k_sc.const_data_ptr<float>(), |
| 510 | + k_zp.const_data_ptr<int8_t>(), |
| 511 | + k_deq.data()); |
| 512 | + dequantize_per_token( |
| 513 | + v.const_data_ptr<int8_t>(), |
| 514 | + 1, |
| 515 | + 3, |
| 516 | + 2, |
| 517 | + 4, |
| 518 | + v_sc.const_data_ptr<float>(), |
| 519 | + v_zp.const_data_ptr<int8_t>(), |
| 520 | + v_deq.data()); |
| 521 | + |
| 522 | + std::vector<float> ref(16, 0.0f); |
| 523 | + compute_reference_sdpa( |
| 524 | + q_deq.data(), |
| 525 | + 1, |
| 526 | + 1, |
| 527 | + 4, |
| 528 | + 4, |
| 529 | + k_deq.data(), |
| 530 | + 3, |
| 531 | + 2, |
| 532 | + v_deq.data(), |
| 533 | + ref.data(), |
| 534 | + false, |
| 535 | + start_pos, |
| 536 | + num_valid); |
| 537 | + |
| 538 | + auto expected = tfFloat.make({1, 1, 4, 4}, ref); |
| 539 | + auto out = tfFloat.zeros({1, 1, 4, 4}); |
| 540 | + call_custom_quantized_sdpa( |
| 541 | + q, |
| 542 | + k, |
| 543 | + v, |
| 544 | + start_pos, |
| 545 | + {}, |
| 546 | + 0.0, |
| 547 | + false, |
| 548 | + {}, |
| 549 | + q_zp, |
| 550 | + q_sc, |
| 551 | + k_zp, |
| 552 | + k_sc, |
| 553 | + v_zp, |
| 554 | + v_sc, |
| 555 | + out); |
| 556 | + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3); |
| 557 | +} |
0 commit comments