@@ -6168,6 +6168,126 @@ struct test_leaky_relu : public test_case {
61686168 }
61696169};
61706170
6171+ // GGML_OP_TURBO_WHT
6172+ struct test_turbo_wht : public test_case {
6173+ const int64_t head_dim;
6174+ const int64_t n_heads;
6175+ const int direction; // 0=forward, 1=inverse
6176+
6177+ std::string vars () override {
6178+ return VARS_TO_STR3 (head_dim, n_heads, direction);
6179+ }
6180+
6181+ double max_nmse_err () override {
6182+ return 1e-5 ; // f32 SIMD reduction order varies across GPU backends
6183+ }
6184+
6185+ test_turbo_wht (int64_t head_dim = 128 , int64_t n_heads = 4 , int direction = 0 )
6186+ : head_dim(head_dim), n_heads(n_heads), direction(direction) {}
6187+
6188+ ggml_tensor * build_graph (ggml_context * ctx) override {
6189+ ggml_tensor * a = ggml_new_tensor_2d (ctx, GGML_TYPE_F32, head_dim, n_heads);
6190+ ggml_set_param (a);
6191+ ggml_set_name (a, " a" );
6192+ ggml_tensor * out = ggml_turbo_wht (ctx, a, direction, 0 , nullptr );
6193+ ggml_set_name (out, " out" );
6194+ return out;
6195+ }
6196+ };
6197+
6198+ // GGML_OP_TURBO_WHT round-trip: forward then inverse should recover the original
6199+ struct test_turbo_wht_roundtrip : public test_case {
6200+ const int64_t head_dim;
6201+ const int64_t n_heads;
6202+
6203+ std::string vars () override {
6204+ return VARS_TO_STR2 (head_dim, n_heads);
6205+ }
6206+
6207+ double max_nmse_err () override {
6208+ return 1e-5 ; // two WHT passes compound the f32 reduction error
6209+ }
6210+
6211+ test_turbo_wht_roundtrip (int64_t head_dim = 128 , int64_t n_heads = 4 )
6212+ : head_dim(head_dim), n_heads(n_heads) {}
6213+
6214+ ggml_tensor * build_graph (ggml_context * ctx) override {
6215+ ggml_tensor * a = ggml_new_tensor_2d (ctx, GGML_TYPE_F32, head_dim, n_heads);
6216+ ggml_set_param (a);
6217+ ggml_set_name (a, " a" );
6218+ // forward WHT (direction=0), then inverse WHT (direction=1)
6219+ ggml_tensor * fwd = ggml_turbo_wht (ctx, a, 0 , 0 , nullptr );
6220+ ggml_tensor * inv = ggml_turbo_wht (ctx, fwd, 1 , 0 , nullptr );
6221+ ggml_set_name (inv, " out" );
6222+ return inv;
6223+ }
6224+ };
6225+
6226+ // Test SET_ROWS with turbo3 destination, then dequantize and compare.
6227+ // This validates the full quantization pipeline: f32 -> WHT -> PolarQuant -> turbo3
6228+ // followed by dequantization: turbo3 -> f32. The round-trip error should be bounded.
6229+ // Unlike the generic SET_ROWS test (which compares raw quantized bytes), this test
6230+ // compares the dequantized f32 output, tolerating the lossy quantization error.
6231+ struct test_set_rows_turbo3 : public test_case {
6232+ const ggml_type type_idx;
6233+ const int64_t ne0; // head dim (must be multiple of 128)
6234+ const int64_t ne1; // rows in dst
6235+ const int r; // rows to write
6236+
6237+ std::string vars () override {
6238+ return VARS_TO_STR4 (type_idx, ne0, ne1, r);
6239+ }
6240+
6241+ std::string op_desc (ggml_tensor * t) override {
6242+ GGML_UNUSED (t);
6243+ return " SET_ROWS_TURBO3" ;
6244+ }
6245+
6246+ test_set_rows_turbo3 (ggml_type type_idx = GGML_TYPE_I32,
6247+ int64_t ne0 = 128 , int64_t ne1 = 8 , int r = 4 )
6248+ : type_idx(type_idx), ne0(ne0), ne1(ne1), r(r) {}
6249+
6250+ ggml_tensor * build_graph (ggml_context * ctx) override {
6251+ // dst: the turbo3 KV cache buffer
6252+ ggml_tensor * dst = ggml_new_tensor_2d (ctx, GGML_TYPE_TURBO3_0, ne0, ne1);
6253+ ggml_set_name (dst, " dst" );
6254+
6255+ // src: f32 values to quantize into the cache
6256+ ggml_tensor * src = ggml_new_tensor_2d (ctx, GGML_TYPE_F32, ne0, r);
6257+ ggml_set_name (src, " src" );
6258+
6259+ // row indices
6260+ ggml_tensor * row_idxs = ggml_new_tensor_1d (ctx, type_idx, r);
6261+ ggml_set_name (row_idxs, " row_idxs" );
6262+
6263+ // Write f32 data into turbo3 dst via SET_ROWS (includes WHT + quantize)
6264+ ggml_tensor * written = ggml_set_rows (ctx, dst, src, row_idxs);
6265+
6266+ // Read it back by dequantizing the written rows to f32
6267+ ggml_tensor * out = ggml_cpy (ctx, written, ggml_new_tensor_2d (ctx, GGML_TYPE_F32, ne0, ne1));
6268+ ggml_set_name (out, " out" );
6269+ return out;
6270+ }
6271+
6272+ void initialize_tensors (ggml_context * ctx) override {
6273+ for (ggml_tensor * t = ggml_get_first_tensor (ctx); t != NULL ; t = ggml_get_next_tensor (ctx, t)) {
6274+ if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
6275+ if (ggml_is_view_op (t->op )) continue ;
6276+ init_set_rows_row_ids (t, ne1);
6277+ } else {
6278+ init_tensor_uniform (t);
6279+ }
6280+ }
6281+ }
6282+
6283+ double max_nmse_err () override {
6284+ // turbo3 is 3-bit quantization with WHT rotation.
6285+ // The round-trip error (f32 -> turbo3 -> f32) is higher than q8_0
6286+ // but bounded. Empirically ~0.02 NMSE for uniform[-1,1] data.
6287+ return 0.05 ;
6288+ }
6289+ };
6290+
61716291// GGML_OP_FLASH_ATTN_EXT
61726292struct test_flash_attn_ext : public test_case {
61736293 const int64_t hsk; // K head size
@@ -8585,6 +8705,38 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
85858705 }
85868706 }
85878707
8708+ // TURBO_WHT tests
8709+ for (int dir : {0 , 1 }) {
8710+ for (int64_t hd : {128 , 256 , 512 }) {
8711+ for (int64_t nh : {1 , 4 , 8 }) {
8712+ test_cases.emplace_back (new test_turbo_wht (hd, nh, dir));
8713+ }
8714+ }
8715+ }
8716+
8717+ // TURBO_WHT round-trip tests (forward then inverse = identity)
8718+ for (int64_t hd : {128 , 256 , 512 }) {
8719+ for (int64_t nh : {1 , 4 , 8 }) {
8720+ test_cases.emplace_back (new test_turbo_wht_roundtrip (hd, nh));
8721+ }
8722+ }
8723+
8724+ // SET_ROWS with turbo3 destination: quantize then dequant round-trip
8725+ // Small tensors (single-dim dispatch)
8726+ for (ggml_type idx_type : {GGML_TYPE_I32, GGML_TYPE_I64}) {
8727+ for (int64_t ne0 : {128 , 256 , 512 }) {
8728+ for (int r : {1 , 4 , 7 }) {
8729+ test_cases.emplace_back (new test_set_rows_turbo3 (idx_type, ne0, 16 , r));
8730+ }
8731+ }
8732+ }
8733+ // Large tensors -- exercises 2D dispatch grid (>512 workgroups),
8734+ // matching actual inference dimensions (4 kv_heads, batch=1024+)
8735+ test_cases.emplace_back (new test_set_rows_turbo3 (GGML_TYPE_I32, 128 , 4096 , 1024 ));
8736+ test_cases.emplace_back (new test_set_rows_turbo3 (GGML_TYPE_I32, 256 , 2048 , 512 ));
8737+ test_cases.emplace_back (new test_set_rows_turbo3 (GGML_TYPE_I32, 512 , 1024 , 256 ));
8738+
8739+
85888740 for (int hsk : { 40 , 64 , 72 , 80 , 96 , 128 , 192 , 256 , 320 , 512 , 576 }) {
85898741 for (int hsv : { 40 , 64 , 72 , 80 , 96 , 128 , 192 , 256 , 512 }) {
85908742 if (hsk != 192 && hsk != 320 && hsk != 576 && hsk != hsv) continue ;
@@ -8612,8 +8764,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
86128764 for (int nb : { 1 , 3 , 32 , 75 , }) {
86138765 for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
86148766 if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue ;
8615- for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
8616- if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72 ) continue ;
8767+ for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_TURBO3_0}) {
8768+ if (type_KV == GGML_TYPE_TURBO3_0 && hsk < 128 ) continue ;
8769+ if (type_KV != GGML_TYPE_F16 && hsk != 64 && hsk != 72 && hsk != 128 ) continue ;
86178770 test_cases.emplace_back (new test_flash_attn_ext (
86188771 hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
86198772 // run fewer test cases permuted
0 commit comments