Skip to content

Commit c299c8f

Browse files
unamedkr authored and claude committed
Fix multi-block attention + wire integer kernel + real-time demo
CRITICAL FIXES:
- BUG #1: Multi-block attention for head_dim > 128 (was processing only the first block).
  Fixed in: tq_uniform.c, tq_mixed.c, tq_polar.c, tq_neon.c, tq_context.c.
  All attention functions now iterate blocks_per_key blocks per key vector.
- BUG #2: Integer attention (tq_uniform_4b_attention_int_ref) is now registered in
  TQ_TRAITS — previously the slow dequantize path was always used.

VERIFIED ON REAL MODEL:
- Qwen3.5-0.8B (head_dim=256, 2 blocks per key): cosine 0.9802 (A grade).
- Was: cosine 0.0000 (completely broken for head_dim > 128).

NEW:
- tests/test_multiblock.cpp: 5 tests for dim=256 and dim=384 multi-block handling.
- tools/tq_realtime_demo.py: end-to-end demo with actual model KV-cache compression
  and TurboQuant attention (not PyTorch).

18/18 tests pass. All existing tests unaffected.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7375cdb commit c299c8f

8 files changed

Lines changed: 793 additions & 254 deletions

File tree

src/backend/cpu/tq_neon.c

Lines changed: 156 additions & 166 deletions
Large diffs are not rendered by default.

src/core/tq_context.c

Lines changed: 30 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -86,12 +86,19 @@ tq_status tq_quantize_keys(tq_context_t* ctx,
8686

8787
pthread_mutex_lock(&ctx->mutex);
8888

89-
size_t type_size = TQ_TRAITS[type].type_size;
89+
size_t block_size = TQ_TRAITS[type].block_size;
90+
size_t type_size = TQ_TRAITS[type].type_size;
91+
int blocks_per_key = (head_dim + (int)block_size - 1) / (int)block_size;
9092
uint8_t* dst = (uint8_t*)out;
9193

9294
for (int i = 0; i < n; i++) {
93-
qfn(keys + i * head_dim, dst, head_dim);
94-
dst += type_size;
95+
for (int b = 0; b < blocks_per_key; b++) {
96+
int offset = b * (int)block_size;
97+
int chunk = head_dim - offset;
98+
if (chunk > (int)block_size) chunk = (int)block_size;
99+
qfn(keys + i * head_dim + offset, dst, chunk);
100+
dst += type_size;
101+
}
95102
}
96103

97104
pthread_mutex_unlock(&ctx->mutex);
@@ -110,12 +117,19 @@ tq_status tq_dequantize_keys(tq_context_t* ctx,
110117
tq_dequantize_fn dfn = TQ_TRAITS[type].dequantize;
111118
if (!dfn) return TQ_ERR_NOT_IMPL;
112119

113-
size_t type_size = TQ_TRAITS[type].type_size;
120+
size_t block_size = TQ_TRAITS[type].block_size;
121+
size_t type_size = TQ_TRAITS[type].type_size;
122+
int blocks_per_key = (head_dim + (int)block_size - 1) / (int)block_size;
114123
const uint8_t* src = (const uint8_t*)quantized;
115124

116125
for (int i = 0; i < n; i++) {
117-
dfn(src, out + i * head_dim, head_dim);
118-
src += type_size;
126+
for (int b = 0; b < blocks_per_key; b++) {
127+
int offset = b * (int)block_size;
128+
int chunk = head_dim - offset;
129+
if (chunk > (int)block_size) chunk = (int)block_size;
130+
dfn(src, out + i * head_dim + offset, chunk);
131+
src += type_size;
132+
}
119133
}
120134

121135
return TQ_OK;
@@ -132,12 +146,19 @@ tq_status tq_quantize_values(tq_context_t* ctx,
132146
tq_quantize_fn qfn = TQ_TRAITS[type].quantize;
133147
if (!qfn) return TQ_ERR_NOT_IMPL;
134148

135-
size_t type_size = TQ_TRAITS[type].type_size;
149+
size_t block_size = TQ_TRAITS[type].block_size;
150+
size_t type_size = TQ_TRAITS[type].type_size;
151+
int blocks_per_key = ((int)head_dim + (int)block_size - 1) / (int)block_size;
136152
uint8_t* dst = (uint8_t*)out;
137153

138154
for (int i = 0; i < n; i++) {
139-
qfn(values + i * head_dim, dst, head_dim);
140-
dst += type_size;
155+
for (int b = 0; b < blocks_per_key; b++) {
156+
int offset = b * (int)block_size;
157+
int chunk = head_dim - offset;
158+
if (chunk > (int)block_size) chunk = (int)block_size;
159+
qfn(values + i * head_dim + offset, dst, chunk);
160+
dst += type_size;
161+
}
141162
}
142163

143164
return TQ_OK;

src/core/tq_mixed.c

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -149,12 +149,21 @@ void tq_mixed_4b8_dequantize_ref(const void* src, float* dst, int n) {
149149

150150
void tq_mixed_4b8_attention_ref(const float* query, const void* kv,
151151
float* scores, int seq_len, int head_dim) {
152-
const block_tq_mixed_4b8* blocks = (const block_tq_mixed_4b8*)kv;
152+
int blocks_per_key = (head_dim + TQ_BK - 1) / TQ_BK;
153+
const block_tq_mixed_4b8* all_blocks = (const block_tq_mixed_4b8*)kv;
154+
153155
for (int s = 0; s < seq_len; s++) {
154-
float deq[256]; /* max head_dim */
155-
tq_mixed_4b8_dequantize_ref(&blocks[s], deq, head_dim);
156156
float dot = 0;
157-
for (int d = 0; d < head_dim; d++) dot += query[d] * deq[d];
157+
for (int b = 0; b < blocks_per_key; b++) {
158+
int offset = b * TQ_BK;
159+
int chunk = (head_dim - offset > TQ_BK) ? TQ_BK : (head_dim - offset);
160+
161+
float deq[TQ_BK];
162+
tq_mixed_4b8_dequantize_ref(&all_blocks[s * blocks_per_key + b], deq, chunk);
163+
164+
for (int d = 0; d < chunk; d++)
165+
dot += query[offset + d] * deq[d];
166+
}
158167
scores[s] = dot;
159168
}
160169
}

src/core/tq_polar.c

Lines changed: 44 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -136,61 +136,57 @@ void tq_polar_dequantize_ref(const void* src, float* dst, int n) {
136136

137137
void tq_polar_attention_ref(const float* query, const void* kv_cache,
138138
float* scores, int seq_len, int head_dim) {
139-
/* Each key is one block_tq_polar covering head_dim elements.
140-
* Instead of dequantizing each key to FP32 then computing dot product,
141-
* we precompute cos/sin/radius lookup tables per block and gather by index.
142-
* This matches the Triton kernel in refs/PolarQuant/models/kernel4group.py. */
143-
const block_tq_polar* blocks = (const block_tq_polar*)kv_cache;
144-
int pairs = head_dim / 2;
145-
if (pairs > TQ_BK / 2) pairs = TQ_BK / 2;
139+
/* Each key may span multiple blocks when head_dim > TQ_BK.
140+
* We precompute cos/sin/radius lookup tables per block and gather by index. */
141+
int blocks_per_key = (head_dim + TQ_BK - 1) / TQ_BK;
142+
const block_tq_polar* all_blocks = (const block_tq_polar*)kv_cache;
146143

147144
/* Theta uses 2 bits (4 levels), rho uses 2 bits (4 levels) */
148145
const int theta_levels = 4;
149146
const int rho_levels = 4;
150147

151148
for (int s = 0; s < seq_len; s++) {
152-
const block_tq_polar* block = &blocks[s];
153-
154-
/* Decode block parameters from FP16 */
155-
float tscale = fp16_to_fp32(block->tscale);
156-
float tmin = fp16_to_fp32(block->tmn);
157-
float rscale = fp16_to_fp32(block->rscale);
158-
float rmin = fp16_to_fp32(block->rmn);
159-
160-
/* Step 1: Precompute theta lookup tables
161-
* For quantization level q: theta = tmin + (q + 0.5) * tscale
162-
* Using floor-based quantization with bin-centered reconstruction
163-
* matching the Triton reference kernel. */
164-
float cos_lut[4], sin_lut[4];
165-
for (int q = 0; q < theta_levels; q++) {
166-
float theta = tmin + ((float)q + 0.5f) * tscale;
167-
cos_lut[q] = cosf(theta);
168-
sin_lut[q] = sinf(theta);
169-
}
170-
171-
/* Step 2: Precompute radius lookup table */
172-
float radius_lut[4];
173-
for (int q = 0; q < rho_levels; q++) {
174-
radius_lut[q] = rmin + ((float)q + 0.5f) * rscale;
175-
}
176-
177-
/* Step 3: For each pair, gather from LUT by index and accumulate */
178149
float score = 0.0f;
179-
for (int i = 0; i < pairs; i++) {
180-
/* Extract packed indices (same layout as quantize/dequantize) */
181-
uint8_t byte = block->indices[i / 2];
182-
uint8_t packed = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
183-
int tq = packed & 0x03;
184-
int rq = (packed >> 2) & 0x03;
185-
186-
/* Dot product contribution from this pair:
187-
* key[2i] = radius * cos(theta)
188-
* key[2i+1] = radius * sin(theta)
189-
* contrib = query[2i] * radius * cos(theta) + query[2i+1] * radius * sin(theta)
190-
* = radius * (query[2i] * cos(theta) + query[2i+1] * sin(theta)) */
191-
float contrib = query[2 * i] * cos_lut[tq] + query[2 * i + 1] * sin_lut[tq];
192-
contrib *= radius_lut[rq];
193-
score += contrib;
150+
151+
for (int blk = 0; blk < blocks_per_key; blk++) {
152+
int offset = blk * TQ_BK;
153+
int chunk = (head_dim - offset > TQ_BK) ? TQ_BK : (head_dim - offset);
154+
int pairs = chunk / 2;
155+
156+
const block_tq_polar* block = &all_blocks[s * blocks_per_key + blk];
157+
158+
/* Decode block parameters from FP16 */
159+
float tscale = fp16_to_fp32(block->tscale);
160+
float tmin = fp16_to_fp32(block->tmn);
161+
float rscale = fp16_to_fp32(block->rscale);
162+
float rmin = fp16_to_fp32(block->rmn);
163+
164+
/* Precompute theta lookup tables */
165+
float cos_lut[4], sin_lut[4];
166+
for (int q = 0; q < theta_levels; q++) {
167+
float theta = tmin + ((float)q + 0.5f) * tscale;
168+
cos_lut[q] = cosf(theta);
169+
sin_lut[q] = sinf(theta);
170+
}
171+
172+
/* Precompute radius lookup table */
173+
float radius_lut[4];
174+
for (int q = 0; q < rho_levels; q++) {
175+
radius_lut[q] = rmin + ((float)q + 0.5f) * rscale;
176+
}
177+
178+
/* For each pair, gather from LUT by index and accumulate */
179+
for (int i = 0; i < pairs; i++) {
180+
uint8_t byte = block->indices[i / 2];
181+
uint8_t packed = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
182+
int tq = packed & 0x03;
183+
int rq = (packed >> 2) & 0x03;
184+
185+
float contrib = query[offset + 2 * i] * cos_lut[tq]
186+
+ query[offset + 2 * i + 1] * sin_lut[tq];
187+
contrib *= radius_lut[rq];
188+
score += contrib;
189+
}
194190
}
195191

196192
scores[s] = score;

src/core/tq_traits.c

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,8 @@ extern void tq_uniform_4b_quantize_ref(const float* src, void* dst, int n);
2121
extern void tq_uniform_4b_dequantize_ref(const void* src, float* dst, int n);
2222
extern void tq_uniform_4b_attention_ref(const float* query, const void* kv,
2323
float* scores, int seq_len, int head_dim);
24+
extern void tq_uniform_4b_attention_int_ref(const float* query, const void* kv,
25+
float* scores, int seq_len, int head_dim);
2426
extern void tq_uniform_2b_quantize_ref(const float* src, void* dst, int n);
2527
extern void tq_uniform_2b_dequantize_ref(const void* src, float* dst, int n);
2628
extern void tq_uniform_2b_attention_ref(const float* query, const void* kv,
@@ -89,7 +91,7 @@ const tq_type_traits_t TQ_TRAITS[TQ_TYPE_COUNT] = {
8991
.bpe = (float)sizeof(block_tq_uniform_4b) * 8.0f / TQ_BK,
9092
.quantize = tq_uniform_4b_quantize_ref,
9193
.dequantize = tq_uniform_4b_dequantize_ref,
92-
.attention = tq_uniform_4b_attention_ref,
94+
.attention = tq_uniform_4b_attention_int_ref,
9395
.residual_type = TQ_TYPE_COUNT,
9496
},
9597
[TQ_TYPE_UNIFORM_2B] = {

src/core/tq_uniform.c

Lines changed: 51 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -182,40 +182,56 @@ void tq_uniform_4b_attention_int_ref(const float* query, const void* kv,
182182
float q_scale, q_sum;
183183
tq_quantize_query_q8(query, q8, &q_scale, &q_sum, head_dim);
184184

185-
const block_tq_uniform_4b* blocks = (const block_tq_uniform_4b*)kv;
185+
int blocks_per_key = (head_dim + TQ_BK - 1) / TQ_BK;
186+
const block_tq_uniform_4b* all_blocks = (const block_tq_uniform_4b*)kv;
186187

187188
for (int s = 0; s < seq_len; s++) {
188-
float k_scale = uni_fp16_to_fp32(blocks[s].scale);
189-
float k_zp = uni_fp16_to_fp32(blocks[s].zero_point);
190-
float k_offset = k_zp + 0.5f * k_scale; /* bin centering */
191-
192-
/* Step 2: Integer dot product (no dequantize!) */
193-
int32_t isum = 0;
194-
for (int i = 0; i < head_dim / 2; i++) {
195-
uint8_t packed = blocks[s].qs[i];
196-
int32_t q4_lo = (int32_t)(packed & 0x0F); /* low nibble [0,15] */
197-
int32_t q4_hi = (int32_t)(packed >> 4); /* high nibble [0,15] */
198-
199-
isum += q4_lo * (int32_t)q8[2*i];
200-
isum += q4_hi * (int32_t)q8[2*i + 1];
189+
float score = 0;
190+
for (int b = 0; b < blocks_per_key; b++) {
191+
int offset = b * TQ_BK;
192+
int chunk = (head_dim - offset > TQ_BK) ? TQ_BK : (head_dim - offset);
193+
const block_tq_uniform_4b* block = &all_blocks[s * blocks_per_key + b];
194+
195+
float k_scale = uni_fp16_to_fp32(block->scale);
196+
float k_zp = uni_fp16_to_fp32(block->zero_point);
197+
198+
/* Integer dot product (no dequantize!) */
199+
int32_t isum = 0;
200+
for (int i = 0; i < chunk / 2; i++) {
201+
uint8_t packed = block->qs[i];
202+
isum += (int32_t)(packed & 0x0F) * (int32_t)q8[offset + 2*i];
203+
isum += (int32_t)(packed >> 4) * (int32_t)q8[offset + 2*i + 1];
204+
}
205+
206+
/* Partial query sum for this block's zero-point correction */
207+
float block_q_sum = 0;
208+
for (int d = 0; d < chunk; d++) block_q_sum += query[offset + d];
209+
210+
score += (float)isum * k_scale * q_scale + (k_zp + 0.5f * k_scale) * block_q_sum;
201211
}
202-
203-
/* Step 3: Convert to float ONCE with combined scale
204-
* dot ~ k_scale * q_scale * isum + k_offset * q_sum */
205-
scores[s] = (float)isum * k_scale * q_scale + k_offset * q_sum;
212+
scores[s] = score;
206213
}
207214
}
208215

209216
/* ---------- Uniform 4-bit attention (dequantize + dot product) ---------- */
210217

211218
void tq_uniform_4b_attention_ref(const float* query, const void* kv,
212219
float* scores, int seq_len, int head_dim) {
213-
const block_tq_uniform_4b* blocks = (const block_tq_uniform_4b*)kv;
220+
int blocks_per_key = (head_dim + TQ_BK - 1) / TQ_BK;
221+
const block_tq_uniform_4b* all_blocks = (const block_tq_uniform_4b*)kv;
222+
214223
for (int s = 0; s < seq_len; s++) {
215-
float deq[256]; /* max head_dim */
216-
tq_uniform_4b_dequantize_ref(&blocks[s], deq, head_dim);
217224
float dot = 0;
218-
for (int d = 0; d < head_dim; d++) dot += query[d] * deq[d];
225+
for (int b = 0; b < blocks_per_key; b++) {
226+
int offset = b * TQ_BK;
227+
int chunk = (head_dim - offset > TQ_BK) ? TQ_BK : (head_dim - offset);
228+
229+
float deq[TQ_BK];
230+
tq_uniform_4b_dequantize_ref(&all_blocks[s * blocks_per_key + b], deq, chunk);
231+
232+
for (int d = 0; d < chunk; d++)
233+
dot += query[offset + d] * deq[d];
234+
}
219235
scores[s] = dot;
220236
}
221237
}
@@ -224,12 +240,21 @@ void tq_uniform_4b_attention_ref(const float* query, const void* kv,
224240

225241
void tq_uniform_2b_attention_ref(const float* query, const void* kv,
226242
float* scores, int seq_len, int head_dim) {
227-
const block_tq_uniform_2b* blocks = (const block_tq_uniform_2b*)kv;
243+
int blocks_per_key = (head_dim + TQ_BK - 1) / TQ_BK;
244+
const block_tq_uniform_2b* all_blocks = (const block_tq_uniform_2b*)kv;
245+
228246
for (int s = 0; s < seq_len; s++) {
229-
float deq[256]; /* max head_dim */
230-
tq_uniform_2b_dequantize_ref(&blocks[s], deq, head_dim);
231247
float dot = 0;
232-
for (int d = 0; d < head_dim; d++) dot += query[d] * deq[d];
248+
for (int b = 0; b < blocks_per_key; b++) {
249+
int offset = b * TQ_BK;
250+
int chunk = (head_dim - offset > TQ_BK) ? TQ_BK : (head_dim - offset);
251+
252+
float deq[TQ_BK];
253+
tq_uniform_2b_dequantize_ref(&all_blocks[s * blocks_per_key + b], deq, chunk);
254+
255+
for (int d = 0; d < chunk; d++)
256+
dot += query[offset + d] * deq[d];
257+
}
233258
scores[s] = dot;
234259
}
235260
}

0 commit comments

Comments (0)