@@ -6,6 +6,7 @@ extern void quantize_row_turbo3_0_ref(const float * x, void * y, long long k);
66extern void dequantize_row_turbo3_0 (const void * x , float * y , long long k );
77extern void quantize_row_turbo4_0_ref (const float * x , void * y , long long k );
88extern void dequantize_row_turbo4_0 (const void * x , float * y , long long k );
9+ extern void turbo_cpu_fwht_inverse (float * x , int group_size );
910
1011int main (void ) {
1112 const int d = 128 ;
@@ -15,11 +16,17 @@ int main(void) {
1516
1617 printf ("=== TurboQuant C Round-Trip Test ===\n\n" );
1718
18- /* Test 1: basis vector */
19+ /* Test 1: basis vector
20+ *
21+ * dequantize_row_turbo3_0 leaves output in the WHT-rotated domain (Q is
22+ * also rotated by the graph, so <Q_rot, K_rot> yields correct attention
23+ * scores without an explicit inverse). To verify the round-trip, apply
24+ * the inverse WHT before comparing against the original input. */
1925 memset (input , 0 , sizeof (input ));
2026 input [0 ] = 1.0f ;
2127 quantize_row_turbo3_0_ref (input , buf , d );
2228 dequantize_row_turbo3_0 (buf , output , d );
29+ turbo_cpu_fwht_inverse (output , d );
2330 printf ("Test 1 (turbo3): e0 = [1, 0, ...]\n" );
2431 printf (" In: [%.6f, %.6f, %.6f, %.6f]\n" , input [0 ], input [1 ], input [2 ], input [3 ]);
2532 printf (" Out: [%.6f, %.6f, %.6f, %.6f]\n" , output [0 ], output [1 ], output [2 ], output [3 ]);
@@ -31,17 +38,23 @@ int main(void) {
3138 for (int i = 0 ; i < d ; i ++ ) input [i ] = sinf (i * 0.1f + 0.5f ) * 10.0f ;
3239 quantize_row_turbo3_0_ref (input , buf , d );
3340 dequantize_row_turbo3_0 (buf , output , d );
41+ turbo_cpu_fwht_inverse (output , d );
3442 printf ("Test 2 (turbo3): sin*10\n" );
3543 printf (" In: [%.4f, %.4f, %.4f, %.4f]\n" , input [0 ], input [1 ], input [2 ], input [3 ]);
3644 printf (" Out: [%.4f, %.4f, %.4f, %.4f]\n" , output [0 ], output [1 ], output [2 ], output [3 ]);
3745 mse = cosv = ni = no = 0 ;
3846 for (int i = 0 ; i < d ; i ++ ) { mse += (input [i ]- output [i ])* (input [i ]- output [i ]); cosv += input [i ]* output [i ]; ni += input [i ]* input [i ]; no += output [i ]* output [i ]; }
3947 printf (" MSE=%.8f Cosine=%.6f InNorm=%.2f OutNorm=%.2f\n\n" , mse /d , cosv /sqrtf (ni )/sqrtf (no ), sqrtf (ni ), sqrtf (no ));
4048
41- /* Test 3: turbo4 */
49+ /* Test 3: turbo4
50+ *
51+ * Same convention as turbo3: dequant leaves output in the rotated domain
52+ * (see comment in dequantize_row_turbo4_0 @ ggml-turbo-quant.c). Apply
53+ * the inverse WHT before comparing. */
4254 for (int i = 0 ; i < d ; i ++ ) input [i ] = cosf (i * 0.2f ) * 5.0f ;
4355 quantize_row_turbo4_0_ref (input , buf , d );
4456 dequantize_row_turbo4_0 (buf , output , d );
57+ turbo_cpu_fwht_inverse (output , d );
4558 printf ("Test 3 (turbo4): cos*5\n" );
4659 printf (" In: [%.4f, %.4f, %.4f, %.4f]\n" , input [0 ], input [1 ], input [2 ], input [3 ]);
4760 printf (" Out: [%.4f, %.4f, %.4f, %.4f]\n" , output [0 ], output [1 ], output [2 ], output [3 ]);
0 commit comments