Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 267 additions & 2 deletions extension/llm/custom_ops/op_custom_sdpa_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
*/

// Tests for the unfused SDPA code path (cpu_sdpa) dispatched when
// seq_len == 1 and inputs are non-quantized (the decode fast-path).
// These call custom_sdpa_out directly, not through sdpa_with_kv_cache.
// seq_len == 1 (the decode fast-path). Covers both float and quantized
// inputs. These call custom_sdpa_out / custom_quantized_sdpa_out
// directly, not through sdpa_with_kv_cache.

#include <algorithm>
#include <cmath>
Expand Down Expand Up @@ -117,6 +118,73 @@ void compute_reference_sdpa(
}
}

/**
 * Dequantize an int8 tensor laid out as [B, S, H, D] using per-token
 * quantization parameters laid out as [B, S, H, 1].
 *
 * Each (b, s, h) triple owns one scale / zero_point pair that applies to
 * all D elements of that row: dequant(x) = (x - zero_point) * scale.
 *
 * @param data   int8 input, B*S*H*D contiguous elements.
 * @param B      batch size.
 * @param S      sequence (or kv) length.
 * @param H      number of heads.
 * @param D      head dimension.
 * @param scales per-token scales, B*S*H contiguous elements.
 * @param zps    per-token zero points, B*S*H contiguous elements.
 * @param out    float output, B*S*H*D contiguous elements.
 */
void dequantize_per_token(
    const int8_t* data,
    int B,
    int S,
    int H,
    int D,
    const float* scales,
    const int8_t* zps,
    float* out) {
  // Both layouts are contiguous, so a single flat index over (b, s, h)
  // groups visits exactly the same elements as nested b/s/h loops would:
  // group g covers data[g*D .. g*D + D) and uses scales[g] / zps[g].
  const int groups = B * S * H;
  for (int g = 0; g < groups; ++g) {
    const float scale = scales[g];
    const float zero_point = static_cast<float>(zps[g]);
    const int8_t* row = data + g * D;
    float* dst = out + g * D;
    for (int d = 0; d < D; ++d) {
      dst[d] = (static_cast<float>(row[d]) - zero_point) * scale;
    }
  }
}

// Helper: call custom_quantized_sdpa_out. Inputs use [B, S, H, D] layout.
// Thin passthrough that constructs the KernelRuntimeContext locally and pins
// the layout flag, so individual tests only supply tensor/scalar arguments.
// The zero_point/scale optionals follow the per-token layouts built by the
// tests below ([B, S, H, 1] int8 zero points and float scales).
// NOTE(review): the flag is passed as /*is_seq_at_dim_1=*/false even though
// these tests build tensors with seq at dim 1 — confirm against the op's
// declaration that the parameter name in the inline comment is right (and
// not, e.g., an "is seq at dim 2" flag for which false would be expected).
executorch::aten::Tensor call_custom_quantized_sdpa(
    const executorch::aten::Tensor& q,
    const executorch::aten::Tensor& k,
    const executorch::aten::Tensor& v,
    int64_t start_pos,
    const std::optional<executorch::aten::Tensor>& attn_mask,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale,
    const std::optional<executorch::aten::Tensor>& q_zp,
    const std::optional<executorch::aten::Tensor>& q_sc,
    const std::optional<executorch::aten::Tensor>& k_zp,
    const std::optional<executorch::aten::Tensor>& k_sc,
    const std::optional<executorch::aten::Tensor>& v_zp,
    const std::optional<executorch::aten::Tensor>& v_sc,
    executorch::aten::Tensor& out) {
  executorch::runtime::KernelRuntimeContext ctx{};
  return torch::executor::native::custom_quantized_sdpa_out(
      ctx,
      q,
      k,
      v,
      start_pos,
      attn_mask,
      dropout_p,
      is_causal,
      scale,
      q_zp,
      q_sc,
      k_zp,
      k_sc,
      v_zp,
      v_sc,
      /*is_seq_at_dim_1=*/false,
      out);
}

} // namespace

// With a single KV entry (start_pos=0), output must equal V[0].
Expand Down Expand Up @@ -290,3 +358,200 @@ TEST(OpCustomSdpaTest, DecodeCausalMatchesNonCausal) {

EXPECT_TENSOR_CLOSE_WITH_TOL(out_c, out_nc, 1e-6, 1e-6);
}

// Quantized decode: int8 Q/K/V with per-token scales and zero_points,
// verified against dequantize-then-float-SDPA reference.
TEST(OpCustomSdpaTest, DecodeQuantized) {
  TensorFactory<executorch::aten::ScalarType::Char> tfChar;
  TensorFactory<executorch::aten::ScalarType::Float> tfFloat;

  // Q: [B=1, S=1, H=2, D=4] as int8
  auto q = tfChar.make({1, 1, 2, 4}, {10, 20, -5, 15, -10, 5, 25, -20});

  // K: [B=1, kv_len=3, H=2, D=4] as int8
  auto k = tfChar.make(
      {1, 3, 2, 4}, {8, -12, 18, 5, -3, 22, -8, 14, 15, 7, -20, 10,
                     12, -15, 9, 6, -5, 25, 3, -10, 20, 8, -12, 17});

  // V: [B=1, kv_len=3, H=2, D=4] as int8
  auto v = tfChar.make(
      {1, 3, 2, 4}, {5, 15, -8, 20, 10, -5, 18, 12, -12, 8, 22, -3,
                     7, 20, -10, 15, 18, -5, 10, 3, -8, 12, 5, -20});

  // Per-token scales [B, S/kv, H, 1] and zero_points [B, S/kv, H, 1].
  // One (scale, zero_point) pair per (batch, token, head) row; zero
  // zero-points with a uniform 0.05 scale keep dequantized values simply
  // data * 0.05.
  auto q_sc = tfFloat.make({1, 1, 2, 1}, {0.05f, 0.05f});
  auto k_sc =
      tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
  auto v_sc =
      tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
  auto q_zp = tfChar.make({1, 1, 2, 1}, {0, 0});
  auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
  auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});

  // start_pos is forwarded to both the float reference and the op; all 3
  // KV entries participate (num_valid bounds the reference computation).
  int64_t start_pos = 2;
  int num_valid = 3;

  // Build the float reference: dequantize Q/K/V with the same per-token
  // parameters, then run the float SDPA reference over the result.
  std::vector<float> q_deq(8), k_deq(24), v_deq(24);
  dequantize_per_token(
      q.const_data_ptr<int8_t>(),
      1,
      1,
      2,
      4,
      q_sc.const_data_ptr<float>(),
      q_zp.const_data_ptr<int8_t>(),
      q_deq.data());
  dequantize_per_token(
      k.const_data_ptr<int8_t>(),
      1,
      3,
      2,
      4,
      k_sc.const_data_ptr<float>(),
      k_zp.const_data_ptr<int8_t>(),
      k_deq.data());
  dequantize_per_token(
      v.const_data_ptr<int8_t>(),
      1,
      3,
      2,
      4,
      v_sc.const_data_ptr<float>(),
      v_zp.const_data_ptr<int8_t>(),
      v_deq.data());

  std::vector<float> ref(8, 0.0f);
  compute_reference_sdpa(
      q_deq.data(),
      1,
      1,
      2,
      4,
      k_deq.data(),
      3,
      2,
      v_deq.data(),
      ref.data(),
      false,
      start_pos,
      num_valid);

  // Run the quantized op (non-causal, no mask, default scale) and compare
  // against the float reference; looser 1e-3 tolerance than the float
  // tests absorbs quantized-path rounding.
  auto expected = tfFloat.make({1, 1, 2, 4}, ref);
  auto out = tfFloat.zeros({1, 1, 2, 4});
  call_custom_quantized_sdpa(
      q,
      k,
      v,
      start_pos,
      {},
      0.0,
      false,
      {},
      q_zp,
      q_sc,
      k_zp,
      k_sc,
      v_zp,
      v_sc,
      out);
  EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3);
}

// Quantized GQA decode: 4 query heads sharing 2 KV heads, int8 inputs.
// Same structure as DecodeQuantized, but with H_q (4) != H_kv (2) so each
// KV head serves two query heads.
TEST(OpCustomSdpaTest, DecodeQuantizedGQA) {
  TensorFactory<executorch::aten::ScalarType::Char> tfChar;
  TensorFactory<executorch::aten::ScalarType::Float> tfFloat;

  // Q: [B=1, S=1, H_q=4, D=4] as int8
  auto q = tfChar.make(
      {1, 1, 4, 4},
      {10, 20, -5, 15, -10, 5, 25, -20, 8, -3, 12, 7, -15, 18, 4, -8});

  // K: [B=1, kv_len=3, H_kv=2, D=4] as int8
  auto k = tfChar.make(
      {1, 3, 2, 4}, {8, -12, 18, 5, -3, 22, -8, 14, 15, 7, -20, 10,
                     12, -15, 9, 6, -5, 25, 3, -10, 20, 8, -12, 17});

  // V: [B=1, kv_len=3, H_kv=2, D=4] as int8
  auto v = tfChar.make(
      {1, 3, 2, 4}, {5, 15, -8, 20, 10, -5, 18, 12, -12, 8, 22, -3,
                     7, 20, -10, 15, 18, -5, 10, 3, -8, 12, 5, -20});

  // Per-token params: Q uses H_q=4 rows, K/V use H_kv=2 rows per token.
  // Uniform 0.05 scales and zero zero-points, as in DecodeQuantized.
  auto q_sc = tfFloat.make({1, 1, 4, 1}, {0.05f, 0.05f, 0.05f, 0.05f});
  auto k_sc =
      tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
  auto v_sc =
      tfFloat.make({1, 3, 2, 1}, {0.05f, 0.05f, 0.05f, 0.05f, 0.05f, 0.05f});
  auto q_zp = tfChar.make({1, 1, 4, 1}, {0, 0, 0, 0});
  auto k_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});
  auto v_zp = tfChar.make({1, 3, 2, 1}, {0, 0, 0, 0, 0, 0});

  // start_pos is forwarded to both the float reference and the op; all 3
  // KV entries participate (num_valid bounds the reference computation).
  int64_t start_pos = 2;
  int num_valid = 3;

  // Build the float reference: dequantize Q/K/V, then run float SDPA.
  std::vector<float> q_deq(16), k_deq(24), v_deq(24);
  dequantize_per_token(
      q.const_data_ptr<int8_t>(),
      1,
      1,
      4,
      4,
      q_sc.const_data_ptr<float>(),
      q_zp.const_data_ptr<int8_t>(),
      q_deq.data());
  dequantize_per_token(
      k.const_data_ptr<int8_t>(),
      1,
      3,
      2,
      4,
      k_sc.const_data_ptr<float>(),
      k_zp.const_data_ptr<int8_t>(),
      k_deq.data());
  dequantize_per_token(
      v.const_data_ptr<int8_t>(),
      1,
      3,
      2,
      4,
      v_sc.const_data_ptr<float>(),
      v_zp.const_data_ptr<int8_t>(),
      v_deq.data());

  std::vector<float> ref(16, 0.0f);
  compute_reference_sdpa(
      q_deq.data(),
      1,
      1,
      4,
      4,
      k_deq.data(),
      3,
      2,
      v_deq.data(),
      ref.data(),
      false,
      start_pos,
      num_valid);

  // Run the quantized op (non-causal, no mask, default scale) and compare
  // against the float reference with the quantized-path 1e-3 tolerance.
  auto expected = tfFloat.make({1, 1, 4, 4}, ref);
  auto out = tfFloat.zeros({1, 1, 4, 4});
  call_custom_quantized_sdpa(
      q,
      k,
      v,
      start_pos,
      {},
      0.0,
      false,
      {},
      q_zp,
      q_sc,
      k_zp,
      k_sc,
      v_zp,
      v_sc,
      out);
  EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, 1e-3, 1e-3);
}
10 changes: 8 additions & 2 deletions extension/llm/custom_ops/op_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ Tensor& custom_sdpa_out_impl(
InvalidArgument,
output);

bool use_unfused_sdpa = q.scalar_type() != ScalarType::Char && seq_len == 1;
bool use_unfused_sdpa = seq_len == 1;
if (use_unfused_sdpa) {
ET_SWITCH_FLOAT_TYPES(output.scalar_type(), ctx, "sdpa", CTYPE, [&] {
sdpa::impl::cpu_sdpa<CTYPE>(
Expand All @@ -426,7 +426,13 @@ Tensor& custom_sdpa_out_impl(
scale,
seq_dim,
start_pos,
num_keys_for_causal_attention);
num_keys_for_causal_attention,
q_zero_points,
q_scales,
k_zero_points,
k_scales,
v_zero_points,
v_scales);
});
} else {
ET_SWITCH_FLOAT_TYPES(
Expand Down
Loading
Loading