diff --git a/extension/llm/custom_ops/op_custom_sdpa_test.cpp b/extension/llm/custom_ops/op_custom_sdpa_test.cpp index 92472c7f64a..063d73e620e 100644 --- a/extension/llm/custom_ops/op_custom_sdpa_test.cpp +++ b/extension/llm/custom_ops/op_custom_sdpa_test.cpp @@ -7,8 +7,9 @@ */ // Tests for the unfused SDPA code path (cpu_sdpa) dispatched when -// seq_len == 1 and inputs are non-quantized (the decode fast-path). -// These call custom_sdpa_out directly, not through sdpa_with_kv_cache. +// seq_len == 1 (the decode fast-path). Covers both float and quantized +// inputs. These call custom_sdpa_out / custom_quantized_sdpa_out +// directly, not through sdpa_with_kv_cache. #include #include @@ -117,6 +118,73 @@ void compute_reference_sdpa( } } +/** + * Dequantize int8 tensor in [B, S, H, D] layout using per-token + * scales/zero_points in [B, S, H, 1] layout. + * dequant(x) = (x - zero_point) * scale + */ +void dequantize_per_token( + const int8_t* data, + int B, + int S, + int H, + int D, + const float* scales, + const int8_t* zps, + float* out) { + for (int b = 0; b < B; b++) { + for (int s = 0; s < S; s++) { + for (int h = 0; h < H; h++) { + int param_idx = b * S * H + s * H + h; + float sc = scales[param_idx]; + float zp = static_cast(zps[param_idx]); + for (int d = 0; d < D; d++) { + int idx = b * S * H * D + s * H * D + h * D + d; + out[idx] = (static_cast(data[idx]) - zp) * sc; + } + } + } + } +} + +// Helper: call custom_quantized_sdpa_out. Inputs use [B, S, H, D] layout. 
+executorch::aten::Tensor call_custom_quantized_sdpa(
+    const executorch::aten::Tensor& q,
+    const executorch::aten::Tensor& k,
+    const executorch::aten::Tensor& v,
+    int64_t start_pos,
+    const std::optional<executorch::aten::Tensor>& attn_mask,
+    double dropout_p,
+    bool is_causal,
+    std::optional<double> scale,
+    const std::optional<executorch::aten::Tensor>& q_zp,
+    const std::optional<executorch::aten::Tensor>& q_sc,
+    const std::optional<executorch::aten::Tensor>& k_zp,
+    const std::optional<executorch::aten::Tensor>& k_sc,
+    const std::optional<executorch::aten::Tensor>& v_zp,
+    const std::optional<executorch::aten::Tensor>& v_sc,
+    executorch::aten::Tensor& out) {
+  executorch::runtime::KernelRuntimeContext ctx{};
+  return torch::executor::native::custom_quantized_sdpa_out(
+      ctx,
+      q,
+      k,
+      v,
+      start_pos,
+      attn_mask,
+      dropout_p,
+      is_causal,
+      scale,
+      q_zp,
+      q_sc,
+      k_zp,
+      k_sc,
+      v_zp,
+      v_sc,
+      /*is_seq_at_dim_1=*/false,
+      out);
+}
+
 } // namespace
 
 // With a single KV entry (start_pos=0), output must equal V[0].
@@ -290,3 +358,200 @@ TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) {
 
   EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6);
 }
+
+// Quantized decode: int8 Q/K/V with per-token scales and zero_points,
+// verified against dequantize-then-float-SDPA reference.
+TEST(OpCustomSdpaTest, DecodeQuantized) {
+  TensorFactory<executorch::aten::ScalarType::Char> tfChar;
+  TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
+
+  // Q: [B=1, S=1, H=2, D=4] as int8
+  auto q = tfChar.make({1, 1, 2, 4}, {10, 20, -5, 15, -10, 5, 25, -20});
+
+  // K: [B=1, kv_len=3, H=2, D=4] as int8
+  auto k = tfChar.make(
+      {1, 3, 2, 4}, {8,  -12, 18, 5, -3, 22, -8,  14, 15, 7, -20, 10,
+                     12, -15, 9,  6, -5, 25, 3,  -10, 20, 8, -12, 17});
+
+  // V: [B=1, kv_len=3, H=2, D=4] as int8
+  auto v = tfChar.make(
+      {1, 3, 2, 4}, {5, 15, -8,  20, 10, -5, 18, 12, -12, 8,  22, -3,
+                     7, 20, -10, 15, 18, -5, 10, 3,  -8,  12, 5,  -20});
+
+  // Per-token scales [B, S/kv, H, 1] and zero_points [B, S/kv, H, 1]
+  auto q_sc = tfFloat.make({1, 1, 2, 1}, {0.05f, 0.05f});
+  auto k_sc =
+      tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
+  auto v_sc =
+      tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
+  auto q_zp = tfChar.make({1, 1, 2, 1}, {0, 0});
+  auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
+  auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
+
+  int64_t start_pos = 2;
+  int num_valid = 3;
+
+  // Dequantize and compute float reference
+  std::vector<float> q_deq(8), k_deq(24), v_deq(24);
+  dequantize_per_token(
+      q.const_data_ptr<int8_t>(),
+      1,
+      1,
+      2,
+      4,
+      q_sc.const_data_ptr<float>(),
+      q_zp.const_data_ptr<int8_t>(),
+      q_deq.data());
+  dequantize_per_token(
+      k.const_data_ptr<int8_t>(),
+      1,
+      3,
+      2,
+      4,
+      k_sc.const_data_ptr<float>(),
+      k_zp.const_data_ptr<int8_t>(),
+      k_deq.data());
+  dequantize_per_token(
+      v.const_data_ptr<int8_t>(),
+      1,
+      3,
+      2,
+      4,
+      v_sc.const_data_ptr<float>(),
+      v_zp.const_data_ptr<int8_t>(),
+      v_deq.data());
+
+  std::vector<float> ref(8, 0.0f);
+  compute_reference_sdpa(
+      q_deq.data(),
+      1,
+      1,
+      2,
+      4,
+      k_deq.data(),
+      3,
+      2,
+      v_deq.data(),
+      ref.data(),
+      false,
+      start_pos,
+      num_valid);
+
+  auto expected = tfFloat.make({1, 1, 2, 4}, ref);
+  auto out = tfFloat.zeros({1, 1, 2, 4});
+  call_custom_quantized_sdpa(
+      q,
+      k,
+      v,
+      start_pos,
+      {},
+      0.0,
+      false,
+      {},
+      q_zp,
+      q_sc,
+      k_zp,
+      k_sc,
+      v_zp,
+      v_sc,
+      out);
+  EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3);
+}
+
+// Quantized GQA decode: 4 query heads sharing 2 KV heads, int8 inputs.
+TEST(OpCustomSdpaTest, DecodeQuantizedGQA) {
+  TensorFactory<executorch::aten::ScalarType::Char> tfChar;
+  TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
+
+  // Q: [B=1, S=1, H_q=4, D=4] as int8
+  auto q = tfChar.make(
+      {1, 1, 4, 4},
+      {10, 20, -5, 15, -10, 5, 25, -20, 8, -3, 12, 7, -15, 18, 4, -8});
+
+  // K: [B=1, kv_len=3, H_kv=2, D=4] as int8
+  auto k = tfChar.make(
+      {1, 3, 2, 4}, {8,  -12, 18, 5, -3, 22, -8,  14, 15, 7, -20, 10,
+                     12, -15, 9,  6, -5, 25, 3,  -10, 20, 8, -12, 17});
+
+  // V: [B=1, kv_len=3, H_kv=2, D=4] as int8
+  auto v = tfChar.make(
+      {1, 3, 2, 4}, {5, 15, -8,  20, 10, -5, 18, 12, -12, 8,  22, -3,
+                     7, 20, -10, 15, 18, -5, 10, 3,  -8,  12, 5,  -20});
+
+  auto q_sc = tfFloat.make({1, 1, 4, 1}, {0.05f, 0.05f, 0.05f, 0.05f});
+  auto k_sc =
+      tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
+  auto v_sc =
+      tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
+  auto q_zp = tfChar.make({1, 1, 4, 1}, {0, 0, 0, 0});
+  auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
+  auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
+
+  int64_t start_pos = 2;
+  int num_valid = 3;
+
+  std::vector<float> q_deq(16), k_deq(24), v_deq(24);
+  dequantize_per_token(
+      q.const_data_ptr<int8_t>(),
+      1,
+      1,
+      4,
+      4,
+      q_sc.const_data_ptr<float>(),
+      q_zp.const_data_ptr<int8_t>(),
+      q_deq.data());
+  dequantize_per_token(
+      k.const_data_ptr<int8_t>(),
+      1,
+      3,
+      2,
+      4,
+      k_sc.const_data_ptr<float>(),
+      k_zp.const_data_ptr<int8_t>(),
+      k_deq.data());
+  dequantize_per_token(
+      v.const_data_ptr<int8_t>(),
+      1,
+      3,
+      2,
+      4,
+      v_sc.const_data_ptr<float>(),
+      v_zp.const_data_ptr<int8_t>(),
+      v_deq.data());
+
+  std::vector<float> ref(16, 0.0f);
+  compute_reference_sdpa(
+      q_deq.data(),
+      1,
+      1,
+      4,
+      4,
+      k_deq.data(),
+      3,
+      2,
+      v_deq.data(),
+      ref.data(),
+      false,
+      start_pos,
+      num_valid);
+
+  auto expected = tfFloat.make({1, 1, 4, 4}, ref);
+  auto out = tfFloat.zeros({1, 1, 4, 4});
+  call_custom_quantized_sdpa(
+      q,
+ k, + v, + start_pos, + {}, + 0.0, + false, + {}, + q_zp, + q_sc, + k_zp, + k_sc, + v_zp, + v_sc, + out); + EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3); +} diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 17759fa6dd5..e3b12895926 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -412,7 +412,7 @@ Tensor& custom_sdpa_out_impl( InvalidArgument, output); - bool use_unfused_sdpa = q.scalar_type() != ScalarType::Char && seq_len == 1; + bool use_unfused_sdpa = seq_len == 1; if (use_unfused_sdpa) { ET_SWITCH_FLOAT_TYPES(output.scalar_type(), ctx, "sdpa", CTYPE, [&] { sdpa::impl::cpu_sdpa( @@ -426,7 +426,13 @@ Tensor& custom_sdpa_out_impl( scale, seq_dim, start_pos, - num_keys_for_causal_attention); + num_keys_for_causal_attention, + q_zero_points, + q_scales, + k_zero_points, + k_scales, + v_zero_points, + v_scales); }); } else { ET_SWITCH_FLOAT_TYPES( diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h index 467af1c89f4..ed638d86ff6 100644 --- a/extension/llm/custom_ops/op_sdpa_impl.h +++ b/extension/llm/custom_ops/op_sdpa_impl.h @@ -1123,7 +1123,13 @@ void cpu_sdpa( const optional& scale, const SeqDim seq_dim, const int64_t start_pos, - const int64_t num_keys_for_causal_attention) { + const int64_t num_keys_for_causal_attention, + const optional& q_zero_points = nullopt, + const optional& q_scales = nullopt, + const optional& k_zero_points = nullopt, + const optional& k_scales = nullopt, + const optional& v_zero_points = nullopt, + const optional& v_scales = nullopt) { using accum_t = scalar_t; using Vec = vec::Vectorized; accum_t scaling_factor = static_cast(calculate_scale(query, scale)); @@ -1158,6 +1164,7 @@ void cpu_sdpa( int64_t num_reps = num_head / num_heads_kv; bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel(); + bool is_quantized_sdpa = query.scalar_type() == ScalarType::Char; // Extract 
strides, swapping seq/head dims based on seq_dim auto q_strides = query.strides(); @@ -1186,6 +1193,39 @@ void cpu_sdpa( mStrideM = m_strides[0]; } + int64_t q_quant_params_StrideB = 0; + int64_t q_quant_params_StrideH = 0; + int64_t q_quant_params_StrideM = 0; + int64_t k_quant_params_StrideB = 0; + int64_t k_quant_params_StrideH = 0; + int64_t k_quant_params_StrideN = 0; + int64_t v_quant_params_StrideB = 0; + int64_t v_quant_params_StrideH = 0; + int64_t v_quant_params_StrideN = 0; + + if (is_quantized_sdpa) { + auto q_qp_strides = q_zero_points.value().strides(); + q_quant_params_StrideB = q_qp_strides[0]; + q_quant_params_StrideH = + (seq_dim == SeqDim::ONE) ? q_qp_strides[2] : q_qp_strides[1]; + q_quant_params_StrideM = + (seq_dim == SeqDim::ONE) ? q_qp_strides[1] : q_qp_strides[2]; + + auto k_qp_strides = k_zero_points.value().strides(); + k_quant_params_StrideB = k_qp_strides[0]; + k_quant_params_StrideH = + (seq_dim == SeqDim::ONE) ? k_qp_strides[2] : k_qp_strides[1]; + k_quant_params_StrideN = + (seq_dim == SeqDim::ONE) ? k_qp_strides[1] : k_qp_strides[2]; + + auto v_qp_strides = v_zero_points.value().strides(); + v_quant_params_StrideB = v_qp_strides[0]; + v_quant_params_StrideH = + (seq_dim == SeqDim::ONE) ? v_qp_strides[2] : v_qp_strides[1]; + v_quant_params_StrideN = + (seq_dim == SeqDim::ONE) ? 
v_qp_strides[1] : v_qp_strides[2]; + } + // Allocate per-thread scores buffer: [qSize, kvSize] per (batch, head) #ifdef ET_USE_THREADPOOL int64_t num_thread = @@ -1207,6 +1247,23 @@ void cpu_sdpa( } accum_t* buf_data = reinterpret_cast(buf); + // Allocate dequantization buffer for V (used by _qk_at_v_gemm when m > 4) + int64_t size_per_thread_qdq_vec = kvSize * headSize; + std::unique_ptr allocated_buf_for_qdq; + accum_t* scratch_for_quant_dequant = nullptr; + if (is_quantized_sdpa) { + int64_t size_qdq_bytes = + size_per_thread_qdq_vec * num_thread * sizeof(accum_t); + Result scratch_qdq = ctx.allocate_temp(size_qdq_bytes, 64); + if (!scratch_qdq.ok()) { + allocated_buf_for_qdq = std::make_unique(size_qdq_bytes); + scratch_for_quant_dequant = + reinterpret_cast(allocated_buf_for_qdq.get()); + } else { + scratch_for_quant_dequant = reinterpret_cast(scratch_qdq.get()); + } + } + const scalar_t* q_data = query.const_data_ptr(); const scalar_t* k_data = key.const_data_ptr(); const scalar_t* v_data = value.const_data_ptr(); @@ -1217,47 +1274,85 @@ void cpu_sdpa( auto compute_lambda = [&](int64_t begin, int64_t end) { int64_t ompIdx = torch::executor::get_thread_num(); accum_t* scores = buf_data + ompIdx * scores_per_thread; + accum_t* buf_qdq_ptr = is_quantized_sdpa + ? 
scratch_for_quant_dequant + ompIdx * size_per_thread_qdq_vec + : nullptr; for (int64_t idx = begin; idx < end; ++idx) { int64_t b = idx / num_head; int64_t h = idx % num_head; int64_t kv_h = h / num_reps; - // Pointer to Q[b, h, :, :] and K[b, kv_h, :, :] with appropriate strides - const scalar_t* q_ptr = q_data + b * qStrideB + h * qStrideH; - const scalar_t* k_ptr = k_data + b * kStrideB + kv_h * kStrideH; - const scalar_t* v_ptr = v_data + b * vStrideB + kv_h * vStrideH; + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const float* q_scales_ptr = nullptr; + const float* k_scales_ptr = nullptr; + const float* v_scales_ptr = nullptr; + const int8_t* q_zp_ptr = nullptr; + const int8_t* k_zp_ptr = nullptr; + const int8_t* v_zp_ptr = nullptr; + + int64_t q_offset = b * qStrideB + h * qStrideH; + int64_t k_offset = b * kStrideB + kv_h * kStrideH; + int64_t v_offset = b * vStrideB + kv_h * vStrideH; + + if (is_quantized_sdpa) { + q_ptr = reinterpret_cast(q_data) + q_offset; + k_ptr = reinterpret_cast(k_data) + k_offset; + v_ptr = reinterpret_cast(v_data) + v_offset; + + int64_t q_qp_offset = + b * q_quant_params_StrideB + h * q_quant_params_StrideH; + int64_t k_qp_offset = + b * k_quant_params_StrideB + kv_h * k_quant_params_StrideH; + int64_t v_qp_offset = + b * v_quant_params_StrideB + kv_h * v_quant_params_StrideH; + + q_scales_ptr = q_scales.value().const_data_ptr() + q_qp_offset; + k_scales_ptr = k_scales.value().const_data_ptr() + k_qp_offset; + v_scales_ptr = v_scales.value().const_data_ptr() + v_qp_offset; + q_zp_ptr = q_zero_points.value().const_data_ptr() + q_qp_offset; + k_zp_ptr = k_zero_points.value().const_data_ptr() + k_qp_offset; + v_zp_ptr = v_zero_points.value().const_data_ptr() + v_qp_offset; + } else { + q_ptr = q_data + q_offset; + k_ptr = k_data + k_offset; + v_ptr = v_data + v_offset; + } scalar_t* o_ptr = out_data + b * oStrideB + h * oStrideH; - // GEMM 1: scores[qSize, kvSize] = scaling_factor * Q[qSize, D] @ K^T[D, - // 
kvSize] - ::executorch::cpublas::gemm( - ::executorch::cpublas::TransposeType::Transpose, - ::executorch::cpublas::TransposeType::NoTranspose, - kvSize, + // GEMM 1: scores[qSize, kvSize] = Q[qSize, D] @ K^T[D, kvSize] + MaybeQuantizedMatrixData q_matrix( + q_ptr, + q_zp_ptr, + q_scales_ptr, qSize, headSize, - scaling_factor, + q_quant_params_StrideM, + query.scalar_type()); + MaybeQuantizedMatrixData k_matrix( k_ptr, - kStrideN, - q_ptr, + k_zp_ptr, + k_scales_ptr, + kvSize, + headSize, + k_quant_params_StrideN, + key.scalar_type()); + _q_at_k_gemm( + qSize, + kvSize, + headSize, + q_matrix, qStrideM, - static_cast(0), - scores, - kvSize); + k_matrix, + kStrideN, + scores); - // Causal mask + attention mask + softmax per query row + // Causal mask + scaling + attention mask + softmax per query row for (int64_t qi = 0; qi < qSize; ++qi) { accum_t* row = scores + qi * kvSize; - // Apply attention mask if present - if (has_attn_mask) { - const accum_t* mask_row = mask_data + qi * mStrideM; - for (int64_t j = 0; j < kvSize; ++j) { - row[j] += mask_row[j]; - } - } - // Apply causal mask if (is_causal) { int64_t valid = std::min(start_pos + qi + 1, kvSize); @@ -1266,15 +1361,26 @@ void cpu_sdpa( } } - // Softmax: find max, compute exp, normalize - accum_t max_val = vec::reduce_all( - [](Vec& x, Vec& y) { return vec::maximum(x, y); }, row, kvSize); + accum_t max_val; + const int kvSizeInt = static_cast(kvSize); + if (has_attn_mask) { + // Apply scaling factor and attention mask in fusion + const accum_t* mask_row = mask_data + qi * mStrideM; + for (int64_t j = 0; j < kvSize; ++j) { + row[j] = row[j] * scaling_factor + mask_row[j]; + } + max_val = vec::reduce_all( + [](Vec& x, Vec& y) { return vec::maximum(x, y); }, row, kvSize); + } else { + // Apply scaling factor and find max in fusion + _mul_reduce_max_fusion_kernel( + row, scaling_factor, kvSizeInt, row, max_val); + } if (max_val == -std::numeric_limits::infinity()) { fill_stub(row, static_cast(0), kvSize); } else { 
accum_t sum_val = max_val; - const int kvSizeInt = static_cast(kvSize); _exp_reduce_sum_fusion_kernel(row, kvSizeInt, row, sum_val); accum_t inv_sum = static_cast(1) / sum_val; vec::map( @@ -1283,20 +1389,26 @@ void cpu_sdpa( } // GEMM 2: output[qSize, D] = scores[qSize, kvSize] @ V[kvSize, D] - ::executorch::cpublas::gemm( - ::executorch::cpublas::TransposeType::NoTranspose, - ::executorch::cpublas::TransposeType::NoTranspose, + MaybeQuantizedMatrixData v_matrix( + v_ptr, + v_zp_ptr, + v_scales_ptr, + kvSize, headSize, + v_quant_params_StrideN, + value.scalar_type()); + _qk_at_v_gemm( qSize, + headSize, kvSize, - static_cast(1), - v_ptr, - vStrideN, scores, kvSize, - static_cast(0), + v_matrix, + vStrideN, o_ptr, - oStrideM); + oStrideM, + static_cast(0), + buf_qdq_ptr); } }; torch::executor::parallel_for(0, batchSize * num_head, 1, compute_lambda);