Skip to content

Commit 6a29b58

Browse files
committed
vulkan: add turbo3 backend tests
- test_turbo_wht: forward/inverse WHT, 18 configs. NMSE tolerance 1e-5 (f32 SIMD reduction order varies across GPU backends). - test_turbo_wht_roundtrip: forward then inverse recovers original, 9 configs. NMSE tolerance 1e-5. - test_set_rows_turbo3: full quantization round-trip at small and large tensor sizes. Large tensors exercise the 2D dispatch grid. 21 configs. - Existing: FA with turbo3 KV (528); test_turbo_wht (18) is the same suite listed in the first item and is counted once. - Total: 576 tests (18 + 9 + 21 + 528).
1 parent 1ebda4c commit 6a29b58

1 file changed

Lines changed: 155 additions & 2 deletions

File tree

tests/test-backend-ops.cpp

Lines changed: 155 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6168,6 +6168,126 @@ struct test_leaky_relu : public test_case {
61686168
}
61696169
};
61706170

6171+
// GGML_OP_TURBO_WHT
6172+
struct test_turbo_wht : public test_case {
6173+
const int64_t head_dim;
6174+
const int64_t n_heads;
6175+
const int direction; // 0=forward, 1=inverse
6176+
6177+
std::string vars() override {
6178+
return VARS_TO_STR3(head_dim, n_heads, direction);
6179+
}
6180+
6181+
double max_nmse_err() override {
6182+
return 1e-5; // f32 SIMD reduction order varies across GPU backends
6183+
}
6184+
6185+
test_turbo_wht(int64_t head_dim = 128, int64_t n_heads = 4, int direction = 0)
6186+
: head_dim(head_dim), n_heads(n_heads), direction(direction) {}
6187+
6188+
ggml_tensor * build_graph(ggml_context * ctx) override {
6189+
ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_dim, n_heads);
6190+
ggml_set_param(a);
6191+
ggml_set_name(a, "a");
6192+
ggml_tensor * out = ggml_turbo_wht(ctx, a, direction, 0, nullptr);
6193+
ggml_set_name(out, "out");
6194+
return out;
6195+
}
6196+
};
6197+
6198+
// GGML_OP_TURBO_WHT round-trip: forward then inverse should recover the original
6199+
struct test_turbo_wht_roundtrip : public test_case {
6200+
const int64_t head_dim;
6201+
const int64_t n_heads;
6202+
6203+
std::string vars() override {
6204+
return VARS_TO_STR2(head_dim, n_heads);
6205+
}
6206+
6207+
double max_nmse_err() override {
6208+
return 1e-5; // two WHT passes compound the f32 reduction error
6209+
}
6210+
6211+
test_turbo_wht_roundtrip(int64_t head_dim = 128, int64_t n_heads = 4)
6212+
: head_dim(head_dim), n_heads(n_heads) {}
6213+
6214+
ggml_tensor * build_graph(ggml_context * ctx) override {
6215+
ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_dim, n_heads);
6216+
ggml_set_param(a);
6217+
ggml_set_name(a, "a");
6218+
// forward WHT (direction=0), then inverse WHT (direction=1)
6219+
ggml_tensor * fwd = ggml_turbo_wht(ctx, a, 0, 0, nullptr);
6220+
ggml_tensor * inv = ggml_turbo_wht(ctx, fwd, 1, 0, nullptr);
6221+
ggml_set_name(inv, "out");
6222+
return inv;
6223+
}
6224+
};
6225+
6226+
// Test SET_ROWS with turbo3 destination, then dequantize and compare.
6227+
// This validates the full quantization pipeline: f32 -> WHT -> PolarQuant -> turbo3
6228+
// followed by dequantization: turbo3 -> f32. The round-trip error should be bounded.
6229+
// Unlike the generic SET_ROWS test (which compares raw quantized bytes), this test
6230+
// compares the dequantized f32 output, tolerating the lossy quantization error.
6231+
struct test_set_rows_turbo3 : public test_case {
6232+
const ggml_type type_idx;
6233+
const int64_t ne0; // head dim (must be multiple of 128)
6234+
const int64_t ne1; // rows in dst
6235+
const int r; // rows to write
6236+
6237+
std::string vars() override {
6238+
return VARS_TO_STR4(type_idx, ne0, ne1, r);
6239+
}
6240+
6241+
std::string op_desc(ggml_tensor * t) override {
6242+
GGML_UNUSED(t);
6243+
return "SET_ROWS_TURBO3";
6244+
}
6245+
6246+
test_set_rows_turbo3(ggml_type type_idx = GGML_TYPE_I32,
6247+
int64_t ne0 = 128, int64_t ne1 = 8, int r = 4)
6248+
: type_idx(type_idx), ne0(ne0), ne1(ne1), r(r) {}
6249+
6250+
ggml_tensor * build_graph(ggml_context * ctx) override {
6251+
// dst: the turbo3 KV cache buffer
6252+
ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_TURBO3_0, ne0, ne1);
6253+
ggml_set_name(dst, "dst");
6254+
6255+
// src: f32 values to quantize into the cache
6256+
ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, r);
6257+
ggml_set_name(src, "src");
6258+
6259+
// row indices
6260+
ggml_tensor * row_idxs = ggml_new_tensor_1d(ctx, type_idx, r);
6261+
ggml_set_name(row_idxs, "row_idxs");
6262+
6263+
// Write f32 data into turbo3 dst via SET_ROWS (includes WHT + quantize)
6264+
ggml_tensor * written = ggml_set_rows(ctx, dst, src, row_idxs);
6265+
6266+
// Read it back by dequantizing the written rows to f32
6267+
ggml_tensor * out = ggml_cpy(ctx, written, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, ne1));
6268+
ggml_set_name(out, "out");
6269+
return out;
6270+
}
6271+
6272+
void initialize_tensors(ggml_context * ctx) override {
6273+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
6274+
if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
6275+
if (ggml_is_view_op(t->op)) continue;
6276+
init_set_rows_row_ids(t, ne1);
6277+
} else {
6278+
init_tensor_uniform(t);
6279+
}
6280+
}
6281+
}
6282+
6283+
double max_nmse_err() override {
6284+
// turbo3 is 3-bit quantization with WHT rotation.
6285+
// The round-trip error (f32 -> turbo3 -> f32) is higher than q8_0
6286+
// but bounded. Empirically ~0.02 NMSE for uniform[-1,1] data.
6287+
return 0.05;
6288+
}
6289+
};
6290+
61716291
// GGML_OP_FLASH_ATTN_EXT
61726292
struct test_flash_attn_ext : public test_case {
61736293
const int64_t hsk; // K head size
@@ -8585,6 +8705,38 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
85858705
}
85868706
}
85878707

8708+
// TURBO_WHT tests
8709+
for (int dir : {0, 1}) {
8710+
for (int64_t hd : {128, 256, 512}) {
8711+
for (int64_t nh : {1, 4, 8}) {
8712+
test_cases.emplace_back(new test_turbo_wht(hd, nh, dir));
8713+
}
8714+
}
8715+
}
8716+
8717+
// TURBO_WHT round-trip tests (forward then inverse = identity)
8718+
for (int64_t hd : {128, 256, 512}) {
8719+
for (int64_t nh : {1, 4, 8}) {
8720+
test_cases.emplace_back(new test_turbo_wht_roundtrip(hd, nh));
8721+
}
8722+
}
8723+
8724+
// SET_ROWS with turbo3 destination: quantize then dequant round-trip
8725+
// Small tensors (single-dim dispatch)
8726+
for (ggml_type idx_type : {GGML_TYPE_I32, GGML_TYPE_I64}) {
8727+
for (int64_t ne0 : {128, 256, 512}) {
8728+
for (int r : {1, 4, 7}) {
8729+
test_cases.emplace_back(new test_set_rows_turbo3(idx_type, ne0, 16, r));
8730+
}
8731+
}
8732+
}
8733+
// Large tensors -- exercises 2D dispatch grid (>512 workgroups),
8734+
// matching actual inference dimensions (4 kv_heads, batch=1024+)
8735+
test_cases.emplace_back(new test_set_rows_turbo3(GGML_TYPE_I32, 128, 4096, 1024));
8736+
test_cases.emplace_back(new test_set_rows_turbo3(GGML_TYPE_I32, 256, 2048, 512));
8737+
test_cases.emplace_back(new test_set_rows_turbo3(GGML_TYPE_I32, 512, 1024, 256));
8738+
8739+
85888740
for (int hsk : { 40, 64, 72, 80, 96, 128, 192, 256, 320, 512, 576 }) {
85898741
for (int hsv : { 40, 64, 72, 80, 96, 128, 192, 256, 512 }) {
85908742
if (hsk != 192 && hsk != 320 && hsk != 576 && hsk != hsv) continue;
@@ -8612,8 +8764,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
86128764
for (int nb : { 1, 3, 32, 75, }) {
86138765
for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
86148766
if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
8615-
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
8616-
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72) continue;
8767+
for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_TURBO3_0}) {
8768+
if (type_KV == GGML_TYPE_TURBO3_0 && hsk < 128) continue;
8769+
if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72 && hsk != 128) continue;
86178770
test_cases.emplace_back(new test_flash_attn_ext(
86188771
hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
86198772
// run fewer test cases permuted

0 commit comments

Comments
 (0)