Skip to content

Commit 32c363a

Browse files
unamedkrclaude
andcommitted
Implement true TurboQuant algorithm for KV cache (paper faithful)
New files: - src/core/tq_codebook.c: Lloyd-Max optimal Gaussian codebook (1-4 bit) - src/core/tq_turbo_kv.c: Full pipeline (RHT + codebook + QJL residual) - tests/test_turbo_kv.cpp: 17 tests (roundtrip, attention, comparison) Algorithm (from arXiv 2504.19874): 1. Normalize input, apply Random Hadamard Transform 2. Scalar quantize with optimal Gaussian codebook (b-1 bits) 3. Compute residual, apply QJL 1-bit sign hash 4. Attention: MSE dot product + QJL correction (unbiased estimator) Two variants: TQ_TYPE_TURBO_KV_3B (2+1 bits), TQ_TYPE_TURBO_KV_4B (3+1 bits) Codebook MSE within 1.18x of paper's theoretical bound. 21/21 tests pass, zero warnings. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8b5e096 commit 32c363a

9 files changed

Lines changed: 1175 additions & 10 deletions

File tree

docs/papers/2504.19874v1.pdf

842 KB
Binary file not shown.
File renamed without changes.
File renamed without changes.

include/turboquant/tq_types.h

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ typedef enum {
4949
TQ_TYPE_UNIFORM_4B= 5, /* Min-Max uniform 4-bit */
5050
TQ_TYPE_UNIFORM_2B= 6, /* Min-Max uniform 2-bit */
5151
TQ_TYPE_MIXED_4B8 = 7, /* Mixed: 4-bit base + fp16 outliers */
52-
TQ_TYPE_COUNT = 8
52+
TQ_TYPE_TURBO_KV_3B = 8, /* TurboQuant KV: 2-bit codebook + 1-bit QJL residual */
53+
TQ_TYPE_TURBO_KV_4B = 9, /* TurboQuant KV: 3-bit codebook + 1-bit QJL residual */
54+
TQ_TYPE_COUNT = 10
5355
} tq_type;
5456

5557
/* ============================================================
@@ -175,6 +177,31 @@ typedef struct {
175177
}
176178
#endif
177179

180+
/* TurboQuant KV cache block: RHT + Lloyd-Max codebook + QJL residual
181+
* 3-bit variant: 2-bit codebook (4 levels) + 1-bit QJL sign hash
182+
* Block covers TQ_BK elements (128).
183+
* Layout: norm(2) + residual_norm(2) + rht_seed(4) + mse_2bit(32) + qjl_signs(16) = 56 bytes
184+
*/
185+
typedef struct {
186+
uint16_t norm; /* L2 norm of original vector (fp16) */
187+
uint16_t residual_norm; /* L2 norm of residual after MSE (fp16) */
188+
uint32_t rht_seed; /* RHT random seed for this block */
189+
uint8_t mse_indices[TQ_BK / 4]; /* 2-bit packed codebook indices (32B) */
190+
uint8_t qjl_signs[TQ_BK / 8]; /* 1-bit QJL sign hash on residual (16B) */
191+
} block_tq_turbo_kv_3b;
192+
193+
/* TurboQuant KV cache block: 4-bit variant
194+
* 3-bit codebook (8 levels) + 1-bit QJL sign hash
195+
* Layout: norm(2) + residual_norm(2) + rht_seed(4) + mse_3bit(48) + qjl_signs(16) = 72 bytes
196+
*/
197+
typedef struct {
198+
uint16_t norm; /* L2 norm of original vector (fp16) */
199+
uint16_t residual_norm; /* L2 norm of residual after MSE (fp16) */
200+
uint32_t rht_seed; /* RHT random seed for this block */
201+
uint8_t mse_indices[TQ_BK * 3 / 8]; /* 3-bit packed codebook indices (48B) */
202+
uint8_t qjl_signs[TQ_BK / 8]; /* 1-bit QJL sign hash on residual (16B) */
203+
} block_tq_turbo_kv_4b;
204+
178205
/* ============================================================
179206
* Block size verification (compile-time, C/C++ compatible)
180207
* Uses negative-size array trick for universal compatibility.
@@ -187,5 +214,7 @@ TQ_CHECK_SIZE(block_tq_qjl, 4 + TQ_SKETCH_DIM / 8 + TQ_OUTLIERS);
187214
TQ_CHECK_SIZE(block_tq_uniform_4b, 4 + TQ_BK / 2);
188215
TQ_CHECK_SIZE(block_tq_uniform_2b, 4 + TQ_BK / 4);
189216
TQ_CHECK_SIZE(block_tq_mixed_4b8, 4 + TQ_MIXED_OUTLIERS + TQ_MIXED_OUTLIERS * 2 + TQ_BK / 2);
217+
TQ_CHECK_SIZE(block_tq_turbo_kv_3b, 8 + TQ_BK / 4 + TQ_BK / 8);
218+
TQ_CHECK_SIZE(block_tq_turbo_kv_4b, 8 + TQ_BK * 3 / 8 + TQ_BK / 8);
190219

191220
#endif /* TQ_TYPES_H */

integrations/llamacpp/tq_kv_cache.cpp

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ enum {
3939
GGML_TYPE_TQ_TURBO_4B = GGML_TYPE_TQ_BASE + 4,
4040
GGML_TYPE_TQ_UNIFORM_4B = GGML_TYPE_TQ_BASE + 5,
4141
GGML_TYPE_TQ_UNIFORM_2B = GGML_TYPE_TQ_BASE + 6,
42-
GGML_TYPE_TQ_MIXED_4B8 = GGML_TYPE_TQ_BASE + 7,
43-
GGML_TYPE_TQ_COUNT = 8,
42+
GGML_TYPE_TQ_MIXED_4B8 = GGML_TYPE_TQ_BASE + 7,
43+
GGML_TYPE_TQ_TURBO_KV_3B = GGML_TYPE_TQ_BASE + 8,
44+
GGML_TYPE_TQ_TURBO_KV_4B = GGML_TYPE_TQ_BASE + 9,
45+
GGML_TYPE_TQ_COUNT = 10,
4446
};
4547

4648
/* ============================================================
@@ -56,7 +58,9 @@ static int tq_to_ggml_type(tq_type type) {
5658
case TQ_TYPE_TURBO_4B: return GGML_TYPE_TQ_TURBO_4B;
5759
case TQ_TYPE_UNIFORM_4B: return GGML_TYPE_TQ_UNIFORM_4B;
5860
case TQ_TYPE_UNIFORM_2B: return GGML_TYPE_TQ_UNIFORM_2B;
59-
case TQ_TYPE_MIXED_4B8: return GGML_TYPE_TQ_MIXED_4B8;
61+
case TQ_TYPE_MIXED_4B8: return GGML_TYPE_TQ_MIXED_4B8;
62+
case TQ_TYPE_TURBO_KV_3B: return GGML_TYPE_TQ_TURBO_KV_3B;
63+
case TQ_TYPE_TURBO_KV_4B: return GGML_TYPE_TQ_TURBO_KV_4B;
6064
default: return -1;
6165
}
6266
}
@@ -70,7 +74,9 @@ static tq_type ggml_to_tq_type(int ggml_id) {
7074
case GGML_TYPE_TQ_TURBO_4B: return TQ_TYPE_TURBO_4B;
7175
case GGML_TYPE_TQ_UNIFORM_4B: return TQ_TYPE_UNIFORM_4B;
7276
case GGML_TYPE_TQ_UNIFORM_2B: return TQ_TYPE_UNIFORM_2B;
73-
case GGML_TYPE_TQ_MIXED_4B8: return TQ_TYPE_MIXED_4B8;
77+
case GGML_TYPE_TQ_MIXED_4B8: return TQ_TYPE_MIXED_4B8;
78+
case GGML_TYPE_TQ_TURBO_KV_3B: return TQ_TYPE_TURBO_KV_3B;
79+
case GGML_TYPE_TQ_TURBO_KV_4B: return TQ_TYPE_TURBO_KV_4B;
7480
default: return TQ_TYPE_COUNT;
7581
}
7682
}
@@ -130,7 +136,9 @@ TQ_GGML_WRAPPERS(turbo_3b, TQ_TYPE_TURBO_3B)
130136
TQ_GGML_WRAPPERS(turbo_4b, TQ_TYPE_TURBO_4B)
131137
TQ_GGML_WRAPPERS(uniform_4b, TQ_TYPE_UNIFORM_4B)
132138
TQ_GGML_WRAPPERS(uniform_2b, TQ_TYPE_UNIFORM_2B)
133-
TQ_GGML_WRAPPERS(mixed_4b8, TQ_TYPE_MIXED_4B8)
139+
TQ_GGML_WRAPPERS(mixed_4b8, TQ_TYPE_MIXED_4B8)
140+
TQ_GGML_WRAPPERS(turbo_kv_3b, TQ_TYPE_TURBO_KV_3B)
141+
TQ_GGML_WRAPPERS(turbo_kv_4b, TQ_TYPE_TURBO_KV_4B)
134142

135143
/* ============================================================
136144
* vec_dot wrappers (quantized key . FP32 query -> scalar)
@@ -178,7 +186,9 @@ TQ_GGML_VEC_DOT(turbo_3b, TQ_TYPE_TURBO_3B)
178186
TQ_GGML_VEC_DOT(turbo_4b, TQ_TYPE_TURBO_4B)
179187
TQ_GGML_VEC_DOT(uniform_4b, TQ_TYPE_UNIFORM_4B)
180188
TQ_GGML_VEC_DOT(uniform_2b, TQ_TYPE_UNIFORM_2B)
181-
TQ_GGML_VEC_DOT(mixed_4b8, TQ_TYPE_MIXED_4B8)
189+
TQ_GGML_VEC_DOT(mixed_4b8, TQ_TYPE_MIXED_4B8)
190+
TQ_GGML_VEC_DOT(turbo_kv_3b, TQ_TYPE_TURBO_KV_3B)
191+
TQ_GGML_VEC_DOT(turbo_kv_4b, TQ_TYPE_TURBO_KV_4B)
182192

183193
/* ============================================================
184194
* GGML type trait table
@@ -262,6 +272,22 @@ static const tq_ggml_type_trait TQ_GGML_TRAITS[GGML_TYPE_TQ_COUNT] = {
262272
tq_ggml_to_float_mixed_4b8,
263273
tq_ggml_vec_dot_mixed_4b8,
264274
},
275+
{
276+
"tq_turbo_kv_3b", GGML_TYPE_TQ_TURBO_KV_3B, TQ_TYPE_TURBO_KV_3B,
277+
sizeof(block_tq_turbo_kv_3b), TQ_BK,
278+
(float)sizeof(block_tq_turbo_kv_3b) * 8.0f / TQ_BK,
279+
tq_ggml_from_float_turbo_kv_3b,
280+
tq_ggml_to_float_turbo_kv_3b,
281+
tq_ggml_vec_dot_turbo_kv_3b,
282+
},
283+
{
284+
"tq_turbo_kv_4b", GGML_TYPE_TQ_TURBO_KV_4B, TQ_TYPE_TURBO_KV_4B,
285+
sizeof(block_tq_turbo_kv_4b), TQ_BK,
286+
(float)sizeof(block_tq_turbo_kv_4b) * 8.0f / TQ_BK,
287+
tq_ggml_from_float_turbo_kv_4b,
288+
tq_ggml_to_float_turbo_kv_4b,
289+
tq_ggml_vec_dot_turbo_kv_4b,
290+
},
265291
};
266292

267293
#define TQ_GGML_NUM_TYPES (sizeof(TQ_GGML_TRAITS) / sizeof(TQ_GGML_TRAITS[0]))
@@ -346,9 +372,15 @@ tq_type tq_parse_kv_cache_type(const char* arg) {
346372
{ "uniform4", TQ_TYPE_UNIFORM_4B },
347373
{ "uniform_4b", TQ_TYPE_UNIFORM_4B },
348374
{ "tq-uniform-4b",TQ_TYPE_UNIFORM_4B },
349-
{ "uniform2", TQ_TYPE_UNIFORM_2B },
350-
{ "uniform_2b", TQ_TYPE_UNIFORM_2B },
351-
{ "tq-uniform-2b",TQ_TYPE_UNIFORM_2B },
375+
{ "uniform2", TQ_TYPE_UNIFORM_2B },
376+
{ "uniform_2b", TQ_TYPE_UNIFORM_2B },
377+
{ "tq-uniform-2b", TQ_TYPE_UNIFORM_2B },
378+
{ "turbo_kv_3b", TQ_TYPE_TURBO_KV_3B },
379+
{ "tq-turbo-kv-3b", TQ_TYPE_TURBO_KV_3B },
380+
{ "turbokv3", TQ_TYPE_TURBO_KV_3B },
381+
{ "turbo_kv_4b", TQ_TYPE_TURBO_KV_4B },
382+
{ "tq-turbo-kv-4b", TQ_TYPE_TURBO_KV_4B },
383+
{ "turbokv4", TQ_TYPE_TURBO_KV_4B },
352384
};
353385

354386
for (size_t i = 0; i < sizeof(map) / sizeof(map[0]); i++) {

src/core/tq_codebook.c

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/**
2+
* tq_codebook.c -- Optimal Gaussian Lloyd-Max codebook quantization
3+
*
4+
* Pre-computed optimal centroids for N(0,1) distribution at bit-widths 1-4.
5+
* These are the reconstruction points from the Max-Lloyd algorithm.
6+
* Decision boundaries are midpoints between consecutive centroids.
7+
*
8+
* Usage: After RHT, each coordinate is approximately N(0, 1/sqrt(d)),
9+
* so we scale by inv_std = sqrt(d) to normalize to N(0,1) before
10+
* codebook lookup, then scale back after dequantization.
11+
*/
12+
13+
#include "turboquant/turboquant.h"
14+
#include <math.h>
15+
#include <float.h>
16+
17+
/* ============================================================
18+
* Pre-computed Lloyd-Max centroids for standard normal N(0,1)
19+
* ============================================================ */
20+
21+
/* b=1 (2 levels): E[|X|] for half-normal = sqrt(2/pi) ~ 0.7979 */
22+
static const float CODEBOOK_1BIT[2] = {-0.7979f, 0.7979f};
23+
24+
/* b=2 (4 levels): optimal Lloyd-Max for N(0,1) */
25+
static const float CODEBOOK_2BIT[4] = {-1.5104f, -0.4528f, 0.4528f, 1.5104f};
26+
27+
/* b=3 (8 levels): optimal Lloyd-Max for N(0,1) */
28+
static const float CODEBOOK_3BIT[8] = {
29+
-2.1520f, -1.3440f, -0.7560f, -0.2451f,
30+
0.2451f, 0.7560f, 1.3440f, 2.1520f
31+
};
32+
33+
/* b=4 (16 levels): optimal Lloyd-Max for N(0,1) */
34+
static const float CODEBOOK_4BIT[16] = {
35+
-2.7326f, -2.0690f, -1.6180f, -1.2562f, -0.9423f, -0.6568f, -0.3881f, -0.1284f,
36+
0.1284f, 0.3881f, 0.6568f, 0.9423f, 1.2562f, 1.6180f, 2.0690f, 2.7326f
37+
};
38+
39+
/* Codebook table indexed by bits */
40+
static const float* const CODEBOOKS[5] = {
41+
NULL, /* 0 bits: unused */
42+
CODEBOOK_1BIT, /* 1 bit: 2 levels */
43+
CODEBOOK_2BIT, /* 2 bits: 4 levels */
44+
CODEBOOK_3BIT, /* 3 bits: 8 levels */
45+
CODEBOOK_4BIT /* 4 bits: 16 levels */
46+
};
47+
48+
static const int CODEBOOK_SIZES[5] = {0, 2, 4, 8, 16};
49+
50+
/* ============================================================
51+
* Codebook quantize: find nearest centroid for each element
52+
* ============================================================ */
53+
54+
void tq_codebook_quantize(const float* src, uint8_t* dst_indices,
55+
int n, int bits, float inv_std) {
56+
if (!src || !dst_indices || bits < 1 || bits > 4 || n <= 0) return;
57+
58+
const float* centroids = CODEBOOKS[bits];
59+
int n_levels = CODEBOOK_SIZES[bits];
60+
61+
for (int i = 0; i < n; i++) {
62+
/* Scale to standard normal space */
63+
float x = src[i] * inv_std;
64+
65+
/* Find nearest centroid (linear scan, optimal for small n_levels) */
66+
int best = 0;
67+
float best_dist = fabsf(x - centroids[0]);
68+
for (int c = 1; c < n_levels; c++) {
69+
float dist = fabsf(x - centroids[c]);
70+
if (dist < best_dist) {
71+
best_dist = dist;
72+
best = c;
73+
}
74+
}
75+
dst_indices[i] = (uint8_t)best;
76+
}
77+
}
78+
79+
/* ============================================================
80+
* Codebook dequantize: reconstruct from centroid lookup
81+
* ============================================================ */
82+
83+
void tq_codebook_dequantize(const uint8_t* indices, float* dst,
84+
int n, int bits, float inv_std) {
85+
if (!indices || !dst || bits < 1 || bits > 4 || n <= 0) return;
86+
87+
const float* centroids = CODEBOOKS[bits];
88+
float std_val = (inv_std > 1e-10f) ? (1.0f / inv_std) : 1.0f;
89+
90+
for (int i = 0; i < n; i++) {
91+
dst[i] = centroids[indices[i]] * std_val;
92+
}
93+
}
94+
95+
/* ============================================================
96+
* Codebook helpers: get centroids and number of levels
97+
* ============================================================ */
98+
99+
const float* tq_codebook_centroids(int bits) {
100+
if (bits < 1 || bits > 4) return NULL;
101+
return CODEBOOKS[bits];
102+
}
103+
104+
int tq_codebook_levels(int bits) {
105+
if (bits < 1 || bits > 4) return 0;
106+
return CODEBOOK_SIZES[bits];
107+
}

src/core/tq_traits.c

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,16 @@ extern void tq_mixed_4b8_dequantize_ref(const void* src, float* dst, int n);
3333
extern void tq_mixed_4b8_attention_ref(const float* query, const void* kv,
3434
float* scores, int seq_len, int head_dim);
3535

36+
extern void tq_turbo_kv_3b_quantize_ref(const float* src, void* dst, int n);
37+
extern void tq_turbo_kv_3b_dequantize_ref(const void* src, float* dst, int n);
38+
extern void tq_turbo_kv_3b_attention_ref(const float* query, const void* kv,
39+
float* scores, int seq_len, int head_dim);
40+
41+
extern void tq_turbo_kv_4b_quantize_ref(const float* src, void* dst, int n);
42+
extern void tq_turbo_kv_4b_dequantize_ref(const void* src, float* dst, int n);
43+
extern void tq_turbo_kv_4b_attention_ref(const float* query, const void* kv,
44+
float* scores, int seq_len, int head_dim);
45+
3646
const tq_type_traits_t TQ_TRAITS[TQ_TYPE_COUNT] = {
3747
[TQ_TYPE_POLAR_3B] = {
3848
.name = "polar_3b",
@@ -114,6 +124,26 @@ const tq_type_traits_t TQ_TRAITS[TQ_TYPE_COUNT] = {
114124
.attention = tq_mixed_4b8_attention_ref,
115125
.residual_type = TQ_TYPE_COUNT,
116126
},
127+
[TQ_TYPE_TURBO_KV_3B] = {
128+
.name = "turbo_kv_3b",
129+
.block_size = TQ_BK,
130+
.type_size = sizeof(block_tq_turbo_kv_3b),
131+
.bpe = (float)sizeof(block_tq_turbo_kv_3b) * 8.0f / TQ_BK,
132+
.quantize = tq_turbo_kv_3b_quantize_ref,
133+
.dequantize = tq_turbo_kv_3b_dequantize_ref,
134+
.attention = tq_turbo_kv_3b_attention_ref,
135+
.residual_type = TQ_TYPE_QJL_1B,
136+
},
137+
[TQ_TYPE_TURBO_KV_4B] = {
138+
.name = "turbo_kv_4b",
139+
.block_size = TQ_BK,
140+
.type_size = sizeof(block_tq_turbo_kv_4b),
141+
.bpe = (float)sizeof(block_tq_turbo_kv_4b) * 8.0f / TQ_BK,
142+
.quantize = tq_turbo_kv_4b_quantize_ref,
143+
.dequantize = tq_turbo_kv_4b_dequantize_ref,
144+
.attention = tq_turbo_kv_4b_attention_ref,
145+
.residual_type = TQ_TYPE_QJL_1B,
146+
},
117147
};
118148

119149
const char* tq_type_name(tq_type type) {
@@ -178,6 +208,12 @@ tq_format_spec_t tq_get_format_spec(tq_type type) {
178208
case TQ_TYPE_MIXED_4B8:
179209
spec.algorithm = TQ_ALG_MIXED; spec.key_bits = 4;
180210
spec.outlier_count = TQ_MIXED_OUTLIERS; break;
211+
case TQ_TYPE_TURBO_KV_3B:
212+
spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 3;
213+
spec.flags = TQ_FLAG_HAS_RESIDUAL; break;
214+
case TQ_TYPE_TURBO_KV_4B:
215+
spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 4;
216+
spec.flags = TQ_FLAG_HAS_RESIDUAL; break;
181217
default: break;
182218
}
183219
return spec;

0 commit comments

Comments
 (0)