Skip to content

Commit 0759506

Browse files
committed
fix: inverse WHT in test-turbo-quant.c round-trip (#59)
1 parent 74450af commit 0759506

2 files changed

Lines changed: 46 additions & 2 deletions

File tree

ggml/src/ggml-turbo-quant.c

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,37 @@ static void turbo_cpu_fwht(float * x, int group_size) {
236236
for (int i = 0; i < group_size; i++) x[i] *= inv_sqrt * s2[i];
237237
}
238238

239+
/* ---------- CPU inverse WHT (in-place, group_size elements) ----------
240+
*
241+
* Forward is y = D(s2) * N * H * D(s1) * x (N = 1/sqrt(group_size))
242+
* H is the unnormalized Hadamard butterfly with H*H = group_size * I, so
243+
* (N*H) is self-inverse. s1 and s2 are ±1 diagonals, also self-inverse.
244+
* The inverse therefore has the same structure with s1 and s2 swapped:
245+
* x = D(s1) * N * H * D(s2) * y
246+
*/
247+
GGML_API void turbo_cpu_fwht_inverse(float * x, int group_size) {
248+
const float * s1 = turbo_cpu_s1;
249+
const float * s2 = turbo_cpu_s2;
250+
const float inv_sqrt = (group_size == 128) ? 0.08838834764831845f : 0.125f;
251+
252+
// signs2 (undoes the s2 that was applied last in the forward pass)
253+
for (int i = 0; i < group_size; i++) x[i] *= s2[i];
254+
255+
// butterfly stages (same as forward — self-inverse up to the inv_sqrt scaling below)
256+
for (int h = 1; h < group_size; h *= 2) {
257+
for (int i = 0; i < group_size; i += h * 2) {
258+
for (int j = i; j < i + h; j++) {
259+
float a = x[j], b = x[j + h];
260+
x[j] = a + b;
261+
x[j + h] = a - b;
262+
}
263+
}
264+
}
265+
266+
// normalize + signs1
267+
for (int i = 0; i < group_size; i++) x[i] *= inv_sqrt * s1[i];
268+
}
269+
239270
/* ---------- TURBO3_0: 3-bit PolarQuant with WHT rotation ---------- */
240271

241272
void quantize_row_turbo3_0_ref(const float * GGML_RESTRICT x, block_turbo3_0 * GGML_RESTRICT y, int64_t k) {

tests/test-turbo-quant.c

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ extern void quantize_row_turbo3_0_ref(const float * x, void * y, long long k);
66
extern void dequantize_row_turbo3_0(const void * x, float * y, long long k);
77
extern void quantize_row_turbo4_0_ref(const float * x, void * y, long long k);
88
extern void dequantize_row_turbo4_0(const void * x, float * y, long long k);
9+
extern void turbo_cpu_fwht_inverse(float * x, int group_size);
910

1011
int main(void) {
1112
const int d = 128;
@@ -15,11 +16,17 @@ int main(void) {
1516

1617
printf("=== TurboQuant C Round-Trip Test ===\n\n");
1718

18-
/* Test 1: basis vector */
19+
/* Test 1: basis vector
20+
*
21+
* dequantize_row_turbo3_0 leaves output in the WHT-rotated domain (Q is
22+
* also rotated by the graph, so <Q_rot, K_rot> yields correct attention
23+
* scores without an explicit inverse). To verify the round-trip, apply
24+
* the inverse WHT before comparing against the original input. */
1925
memset(input, 0, sizeof(input));
2026
input[0] = 1.0f;
2127
quantize_row_turbo3_0_ref(input, buf, d);
2228
dequantize_row_turbo3_0(buf, output, d);
29+
turbo_cpu_fwht_inverse(output, d);
2330
printf("Test 1 (turbo3): e0 = [1, 0, ...]\n");
2431
printf(" In: [%.6f, %.6f, %.6f, %.6f]\n", input[0], input[1], input[2], input[3]);
2532
printf(" Out: [%.6f, %.6f, %.6f, %.6f]\n", output[0], output[1], output[2], output[3]);
@@ -31,17 +38,23 @@ int main(void) {
3138
for (int i = 0; i < d; i++) input[i] = sinf(i*0.1f+0.5f) * 10.0f;
3239
quantize_row_turbo3_0_ref(input, buf, d);
3340
dequantize_row_turbo3_0(buf, output, d);
41+
turbo_cpu_fwht_inverse(output, d);
3442
printf("Test 2 (turbo3): sin*10\n");
3543
printf(" In: [%.4f, %.4f, %.4f, %.4f]\n", input[0], input[1], input[2], input[3]);
3644
printf(" Out: [%.4f, %.4f, %.4f, %.4f]\n", output[0], output[1], output[2], output[3]);
3745
mse = cosv = ni = no = 0;
3846
for (int i = 0; i < d; i++) { mse += (input[i]-output[i])*(input[i]-output[i]); cosv += input[i]*output[i]; ni += input[i]*input[i]; no += output[i]*output[i]; }
3947
printf(" MSE=%.8f Cosine=%.6f InNorm=%.2f OutNorm=%.2f\n\n", mse/d, cosv/sqrtf(ni)/sqrtf(no), sqrtf(ni), sqrtf(no));
4048

41-
/* Test 3: turbo4 */
49+
/* Test 3: turbo4
50+
*
51+
* Same convention as turbo3: dequant leaves output in the rotated domain
52+
* (see comment in dequantize_row_turbo4_0 @ ggml-turbo-quant.c). Apply
53+
* the inverse WHT before comparing. */
4254
for (int i = 0; i < d; i++) input[i] = cosf(i*0.2f) * 5.0f;
4355
quantize_row_turbo4_0_ref(input, buf, d);
4456
dequantize_row_turbo4_0(buf, output, d);
57+
turbo_cpu_fwht_inverse(output, d);
4558
printf("Test 3 (turbo4): cos*5\n");
4659
printf(" In: [%.4f, %.4f, %.4f, %.4f]\n", input[0], input[1], input[2], input[3]);
4760
printf(" Out: [%.4f, %.4f, %.4f, %.4f]\n", output[0], output[1], output[2], output[3]);

0 commit comments

Comments
 (0)