|
7 | 7 | */ |
8 | 8 |
|
9 | 9 | // Tests for the unfused SDPA code path (cpu_sdpa) dispatched when |
10 | | -// seq_len == 1 and inputs are non-quantized (the decode fast-path). |
11 | | -// These call custom_sdpa_out directly, not through sdpa_with_kv_cache. |
| 10 | +// seq_len == 1 (the decode fast-path). Covers both float and quantized |
| 11 | +// inputs. These call custom_sdpa_out / custom_quantized_sdpa_out |
| 12 | +// directly, not through sdpa_with_kv_cache. |
12 | 13 |
|
13 | 14 | #include <algorithm> |
14 | 15 | #include <cmath> |
@@ -114,6 +115,55 @@ void compute_reference_sdpa( |
114 | 115 | } |
115 | 116 | } |
116 | 117 |
|
/**
 * Dequantize an int8 tensor stored contiguously in [B, S, H, D] layout into
 * floats, using one scale and one zero point per (b, s, h) row, i.e.
 * quantization parameters stored contiguously in [B, S, H, 1] layout:
 *
 *   out[b, s, h, d] = (data[b, s, h, d] - zps[b, s, h]) * scales[b, s, h]
 *
 * @param data   Quantized input, B * S * H * D elements.
 * @param B      Size of dimension 0.
 * @param S      Size of dimension 1.
 * @param H      Size of dimension 2.
 * @param D      Size of dimension 3 (elements sharing one scale/zero point).
 * @param scales Per-row scales, B * S * H elements.
 * @param zps    Per-row zero points, B * S * H elements.
 * @param out    Destination buffer, B * S * H * D floats.
 *
 * Nothing is read or written when any dimension is non-positive.
 */
void dequantize_per_token(
    const int8_t* data, int B, int S, int H, int D,
    const float* scales,
    const int8_t* zps,
    float* out) {
  if (B <= 0 || S <= 0 || H <= 0 || D <= 0) {
    return;
  }

  // Offsets are computed in 64 bits: the product B * S * H * D can exceed
  // the range of int for large shapes, and signed overflow is undefined.
  const int64_t num_rows = static_cast<int64_t>(B) * S * H;
  const int64_t row_len = D;

  // Rows are visited in (b, s, h) order, so the flat row index is exactly
  // b * S * H + s * H + h and doubles as the index into scales/zps.
  for (int64_t row = 0; row < num_rows; ++row) {
    const float sc = scales[row];
    const float zp = static_cast<float>(zps[row]);
    const int8_t* src = data + row * row_len;
    float* dst = out + row * row_len;
    for (int64_t d = 0; d < row_len; ++d) {
      dst[d] = (static_cast<float>(src[d]) - zp) * sc;
    }
  }
}
| 142 | + |
| 143 | +// Helper: call custom_quantized_sdpa_out. Inputs use [B, S, H, D] layout. |
| 144 | +executorch::aten::Tensor call_custom_quantized_sdpa( |
| 145 | + const executorch::aten::Tensor& q, |
| 146 | + const executorch::aten::Tensor& k, |
| 147 | + const executorch::aten::Tensor& v, |
| 148 | + int64_t start_pos, |
| 149 | + const std::optional<executorch::aten::Tensor>& attn_mask, |
| 150 | + double dropout_p, |
| 151 | + bool is_causal, |
| 152 | + std::optional<double> scale, |
| 153 | + const std::optional<executorch::aten::Tensor>& q_zp, |
| 154 | + const std::optional<executorch::aten::Tensor>& q_sc, |
| 155 | + const std::optional<executorch::aten::Tensor>& k_zp, |
| 156 | + const std::optional<executorch::aten::Tensor>& k_sc, |
| 157 | + const std::optional<executorch::aten::Tensor>& v_zp, |
| 158 | + const std::optional<executorch::aten::Tensor>& v_sc, |
| 159 | + executorch::aten::Tensor& out) { |
| 160 | + executorch::runtime::KernelRuntimeContext ctx{}; |
| 161 | + return torch::executor::native::custom_quantized_sdpa_out( |
| 162 | + ctx, q, k, v, start_pos, attn_mask, dropout_p, is_causal, scale, |
| 163 | + q_zp, q_sc, k_zp, k_sc, v_zp, v_sc, |
| 164 | + /*is_seq_at_dim_1=*/false, out); |
| 165 | +} |
| 166 | + |
117 | 167 | } // namespace |
118 | 168 |
|
119 | 169 | // With a single KV entry (start_pos=0), output must equal V[0]. |
@@ -263,3 +313,137 @@ TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) { |
263 | 313 |
|
264 | 314 | EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6); |
265 | 315 | } |
| 316 | + |
| 317 | +// Quantized decode: int8 Q/K/V with per-token scales and zero_points, |
| 318 | +// verified against dequantize-then-float-SDPA reference. |
| 319 | +TEST(OpCustomSdpaTest, DecodeQuantized) { |
| 320 | + TensorFactory<executorch::aten::ScalarType::Char> tfChar; |
| 321 | + TensorFactory<executorch::aten::ScalarType::Float> tfFloat; |
| 322 | + |
| 323 | + // Q: [B=1, S=1, H=2, D=4] as int8 |
| 324 | + auto q = tfChar.make( |
| 325 | + {1, 1, 2, 4}, |
| 326 | + {10, 20, -5, 15, -10, 5, 25, -20}); |
| 327 | + |
| 328 | + // K: [B=1, kv_len=3, H=2, D=4] as int8 |
| 329 | + auto k = tfChar.make( |
| 330 | + {1, 3, 2, 4}, |
| 331 | + {8, -12, 18, 5, -3, 22, -8, 14, |
| 332 | + 15, 7, -20, 10, 12, -15, 9, 6, |
| 333 | + -5, 25, 3, -10, 20, 8, -12, 17}); |
| 334 | + |
| 335 | + // V: [B=1, kv_len=3, H=2, D=4] as int8 |
| 336 | + auto v = tfChar.make( |
| 337 | + {1, 3, 2, 4}, |
| 338 | + {5, 15, -8, 20, 10, -5, 18, 12, |
| 339 | + -12, 8, 22, -3, 7, 20, -10, 15, |
| 340 | + 18, -5, 10, 3, -8, 12, 5, -20}); |
| 341 | + |
| 342 | + // Per-token scales [B, S/kv, H, 1] and zero_points [B, S/kv, H, 1] |
| 343 | + auto q_sc = tfFloat.make({1, 1, 2, 1}, {0.05f, 0.05f}); |
| 344 | + auto k_sc = tfFloat.make({1, 3, 2, 1}, |
| 345 | + {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f}); |
| 346 | + auto v_sc = tfFloat.make({1, 3, 2, 1}, |
| 347 | + {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f}); |
| 348 | + auto q_zp = tfChar.make({1, 1, 2, 1}, {0, 0}); |
| 349 | + auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0}); |
| 350 | + auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0}); |
| 351 | + |
| 352 | + int64_t start_pos = 2; |
| 353 | + int num_valid = 3; |
| 354 | + |
| 355 | + // Dequantize and compute float reference |
| 356 | + std::vector<float> q_deq(8), k_deq(24), v_deq(24); |
| 357 | + dequantize_per_token( |
| 358 | + q.const_data_ptr<int8_t>(), 1, 1, 2, 4, |
| 359 | + q_sc.const_data_ptr<float>(), q_zp.const_data_ptr<int8_t>(), |
| 360 | + q_deq.data()); |
| 361 | + dequantize_per_token( |
| 362 | + k.const_data_ptr<int8_t>(), 1, 3, 2, 4, |
| 363 | + k_sc.const_data_ptr<float>(), k_zp.const_data_ptr<int8_t>(), |
| 364 | + k_deq.data()); |
| 365 | + dequantize_per_token( |
| 366 | + v.const_data_ptr<int8_t>(), 1, 3, 2, 4, |
| 367 | + v_sc.const_data_ptr<float>(), v_zp.const_data_ptr<int8_t>(), |
| 368 | + v_deq.data()); |
| 369 | + |
| 370 | + std::vector<float> ref(8, 0.0f); |
| 371 | + compute_reference_sdpa( |
| 372 | + q_deq.data(), 1, 1, 2, 4, |
| 373 | + k_deq.data(), 3, 2, |
| 374 | + v_deq.data(), |
| 375 | + ref.data(), false, start_pos, num_valid); |
| 376 | + |
| 377 | + auto expected = tfFloat.make({1, 1, 2, 4}, ref); |
| 378 | + auto out = tfFloat.zeros({1, 1, 2, 4}); |
| 379 | + call_custom_quantized_sdpa( |
| 380 | + q, k, v, start_pos, {}, 0.0, false, {}, |
| 381 | + q_zp, q_sc, k_zp, k_sc, v_zp, v_sc, out); |
| 382 | + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3); |
| 383 | +} |
| 384 | + |
| 385 | +// Quantized GQA decode: 4 query heads sharing 2 KV heads, int8 inputs. |
| 386 | +TEST(OpCustomSdpaTest, DecodeQuantizedGQA) { |
| 387 | + TensorFactory<executorch::aten::ScalarType::Char> tfChar; |
| 388 | + TensorFactory<executorch::aten::ScalarType::Float> tfFloat; |
| 389 | + |
| 390 | + // Q: [B=1, S=1, H_q=4, D=4] as int8 |
| 391 | + auto q = tfChar.make( |
| 392 | + {1, 1, 4, 4}, |
| 393 | + {10, 20, -5, 15, -10, 5, 25, -20, |
| 394 | + 8, -3, 12, 7, -15, 18, 4, -8}); |
| 395 | + |
| 396 | + // K: [B=1, kv_len=3, H_kv=2, D=4] as int8 |
| 397 | + auto k = tfChar.make( |
| 398 | + {1, 3, 2, 4}, |
| 399 | + {8, -12, 18, 5, -3, 22, -8, 14, |
| 400 | + 15, 7, -20, 10, 12, -15, 9, 6, |
| 401 | + -5, 25, 3, -10, 20, 8, -12, 17}); |
| 402 | + |
| 403 | + // V: [B=1, kv_len=3, H_kv=2, D=4] as int8 |
| 404 | + auto v = tfChar.make( |
| 405 | + {1, 3, 2, 4}, |
| 406 | + {5, 15, -8, 20, 10, -5, 18, 12, |
| 407 | + -12, 8, 22, -3, 7, 20, -10, 15, |
| 408 | + 18, -5, 10, 3, -8, 12, 5, -20}); |
| 409 | + |
| 410 | + auto q_sc = tfFloat.make({1, 1, 4, 1}, {0.05f, 0.05f, 0.05f, 0.05f}); |
| 411 | + auto k_sc = tfFloat.make({1, 3, 2, 1}, |
| 412 | + {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f}); |
| 413 | + auto v_sc = tfFloat.make({1, 3, 2, 1}, |
| 414 | + {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f}); |
| 415 | + auto q_zp = tfChar.make({1, 1, 4, 1}, {0, 0, 0, 0}); |
| 416 | + auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0}); |
| 417 | + auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0}); |
| 418 | + |
| 419 | + int64_t start_pos = 2; |
| 420 | + int num_valid = 3; |
| 421 | + |
| 422 | + std::vector<float> q_deq(16), k_deq(24), v_deq(24); |
| 423 | + dequantize_per_token( |
| 424 | + q.const_data_ptr<int8_t>(), 1, 1, 4, 4, |
| 425 | + q_sc.const_data_ptr<float>(), q_zp.const_data_ptr<int8_t>(), |
| 426 | + q_deq.data()); |
| 427 | + dequantize_per_token( |
| 428 | + k.const_data_ptr<int8_t>(), 1, 3, 2, 4, |
| 429 | + k_sc.const_data_ptr<float>(), k_zp.const_data_ptr<int8_t>(), |
| 430 | + k_deq.data()); |
| 431 | + dequantize_per_token( |
| 432 | + v.const_data_ptr<int8_t>(), 1, 3, 2, 4, |
| 433 | + v_sc.const_data_ptr<float>(), v_zp.const_data_ptr<int8_t>(), |
| 434 | + v_deq.data()); |
| 435 | + |
| 436 | + std::vector<float> ref(16, 0.0f); |
| 437 | + compute_reference_sdpa( |
| 438 | + q_deq.data(), 1, 1, 4, 4, |
| 439 | + k_deq.data(), 3, 2, |
| 440 | + v_deq.data(), |
| 441 | + ref.data(), false, start_pos, num_valid); |
| 442 | + |
| 443 | + auto expected = tfFloat.make({1, 1, 4, 4}, ref); |
| 444 | + auto out = tfFloat.zeros({1, 1, 4, 4}); |
| 445 | + call_custom_quantized_sdpa( |
| 446 | + q, k, v, start_pos, {}, 0.0, false, {}, |
| 447 | + q_zp, q_sc, k_zp, k_sc, v_zp, v_sc, out); |
| 448 | + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3); |
| 449 | +} |
0 commit comments