Skip to content

Commit 781cca0

Browse files
unamedkr and claude committed
A1: Dynamic buffer allocation for 3B+ model support
Replace 12 stack-allocated fixed-size arrays with dynamic buffers in tq_state_t, sized from the model config at runtime:
- xb_q8/xb_q8s: activation Q8 workspace (was float[4096])
- gate_vals/decay_vals: DeltaNet gates (was float[128])
- delta_sk/delta_dvec: DeltaNet workspace (was float[128])

Removes the hardcoded 4096/128 limits that would overflow on 3B+ models. All existing tests pass; Qwen3.5 and Gemma3 are unaffected.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b4949fe commit 781cca0

2 files changed

Lines changed: 64 additions & 38 deletions

File tree

include/turboquant/tq_engine.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,14 @@ typedef struct {
196196
float* delta_ab; /* [delta_n_heads * 2] workspace for a,b projections */
197197
float* delta_out; /* [z_dim] workspace for output */
198198

199+
/* Dynamic workspace buffers (sized from model config, replacing stack arrays) */
200+
int8_t* xb_q8; /* [hidden_dim] pre-quantized activation for Q4 matmuls */
201+
float* xb_q8s; /* [hidden_dim/32 + 1] Q8 scales for xb_q8 */
202+
float* gate_vals; /* [delta_n_heads] DeltaNet gate values */
203+
float* decay_vals; /* [delta_n_heads] DeltaNet precomputed exp(gate) */
204+
float* delta_sk; /* [delta_value_head_dim] DeltaNet S@K workspace */
205+
float* delta_dvec; /* [delta_value_head_dim] DeltaNet delta workspace */
206+
199207
/* Quantization workspace */
200208
void* quant_key_buf; /* workspace for quantized keys */
201209
float* quant_score_buf; /* workspace for quantized attention scores */

src/engine/tq_transformer.c

Lines changed: 56 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ tq_state_t* tq_create_state(const tq_model_config_t* config, tq_type kv_type) {
7979
s->value_cache = (float*)calloc((size_t)n_layers * kv_layer_size, sizeof(float));
8080
s->kv_cache_size = (size_t)n_layers * kv_layer_size * sizeof(float);
8181

82+
/* Dynamic workspace buffers (replacing fixed-size stack arrays).
83+
* xb_q8/xb_q8s are used in deltanet_forward, self_attn_forward, and FFN
84+
* for pre-quantizing activations to Q8 before Q4 matmuls. */
85+
int q8_blocks = (dim + 31) / 32;
86+
s->xb_q8 = (int8_t*)calloc((size_t)dim, sizeof(int8_t));
87+
s->xb_q8s = (float*)calloc((size_t)(q8_blocks + 1), sizeof(float));
88+
8289
/* DeltaNet recurrent state */
8390
if (config->delta_n_heads > 0) {
8491
int dn = config->delta_n_heads;
@@ -96,6 +103,12 @@ tq_state_t* tq_create_state(const tq_model_config_t* config, tq_type kv_type) {
96103
s->delta_z = (float*)calloc((size_t)delta_z_dim, sizeof(float));
97104
s->delta_ab = (float*)calloc((size_t)dn * 2, sizeof(float));
98105
s->delta_out = (float*)calloc((size_t)delta_z_dim, sizeof(float));
106+
107+
/* DeltaNet per-head workspace (replacing stack-allocated gate_vals/decay_vals/sk/d_vec) */
108+
s->gate_vals = (float*)calloc((size_t)dn, sizeof(float));
109+
s->decay_vals = (float*)calloc((size_t)dn, sizeof(float));
110+
s->delta_sk = (float*)calloc((size_t)dv, sizeof(float));
111+
s->delta_dvec = (float*)calloc((size_t)dv, sizeof(float));
99112
}
100113

101114
/* Quantization workspace */
@@ -129,7 +142,8 @@ tq_state_t* tq_create_state(const tq_model_config_t* config, tq_type kv_type) {
129142
/* Verify critical allocations */
130143
if (!s->x || !s->xb || !s->xb2 || !s->q || !s->k || !s->v ||
131144
!s->att || !s->hb || !s->hb2 || !s->logits ||
132-
!s->key_cache || !s->value_cache) {
145+
!s->key_cache || !s->value_cache ||
146+
!s->xb_q8 || !s->xb_q8s) {
133147
tq_free_state(s);
134148
return NULL;
135149
}
@@ -157,6 +171,12 @@ void tq_free_state(tq_state_t* state) {
157171
free(state->delta_z);
158172
free(state->delta_ab);
159173
free(state->delta_out);
174+
free(state->xb_q8);
175+
free(state->xb_q8s);
176+
free(state->gate_vals);
177+
free(state->decay_vals);
178+
free(state->delta_sk);
179+
free(state->delta_dvec);
160180
free(state->quant_key_buf);
161181
free(state->quant_score_buf);
162182
free(state->quant_key_cache);
@@ -341,23 +361,21 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
341361

342362
/* Pre-quantize activation to Q8 once for all Q4 projections in this layer.
343363
* This eliminates 4 redundant tq_quantize_row_q8 + malloc/free cycles. */
344-
int8_t xb_q8[4096]; /* max hidden_dim, stack-allocated for speed */
345-
float xb_q8s[128 + 1]; /* max blocks + 1 */
346364
int has_q4 = (layer->delta_in_proj_qkv_q4 != NULL);
347365
if (has_q4) {
348-
tq_quantize_row_q8(s->xb, xb_q8, xb_q8s, dim);
366+
tq_quantize_row_q8(s->xb, s->xb_q8, s->xb_q8s, dim);
349367
}
350368

351369
/* Step 1: Project input through QKV and Z */
352370
if (layer->delta_in_proj_qkv_q4)
353-
tq_matmul_q4_preq(s->delta_qkv, layer->delta_in_proj_qkv_q4, layer->delta_in_proj_qkv_q4s, xb_q8, xb_q8s, qkv_dim, dim);
371+
tq_matmul_q4_preq(s->delta_qkv, layer->delta_in_proj_qkv_q4, layer->delta_in_proj_qkv_q4s, s->xb_q8, s->xb_q8s, qkv_dim, dim);
354372
else if (layer->delta_in_proj_qkv_q8)
355373
tq_matmul_q8(s->delta_qkv, s->xb, layer->delta_in_proj_qkv_q8, layer->delta_in_proj_qkv_q8s, qkv_dim, dim);
356374
else
357375
tq_matmul(s->delta_qkv, s->xb, layer->delta_in_proj_qkv, qkv_dim, dim);
358376

359377
if (layer->delta_in_proj_z_q4)
360-
tq_matmul_q4_preq(s->delta_z, layer->delta_in_proj_z_q4, layer->delta_in_proj_z_q4s, xb_q8, xb_q8s, z_dim, dim);
378+
tq_matmul_q4_preq(s->delta_z, layer->delta_in_proj_z_q4, layer->delta_in_proj_z_q4s, s->xb_q8, s->xb_q8s, z_dim, dim);
361379
else if (layer->delta_in_proj_z_q8)
362380
tq_matmul_q8(s->delta_z, s->xb, layer->delta_in_proj_z_q8, layer->delta_in_proj_z_q8s, z_dim, dim);
363381
else
@@ -366,15 +384,15 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
366384
/* Step 2: Project alpha and beta */
367385
/* alpha = in_proj_a @ x -> [dn] */
368386
if (layer->delta_in_proj_a_q4)
369-
tq_matmul_q4_preq(s->delta_ab, layer->delta_in_proj_a_q4, layer->delta_in_proj_a_q4s, xb_q8, xb_q8s, dn, dim);
387+
tq_matmul_q4_preq(s->delta_ab, layer->delta_in_proj_a_q4, layer->delta_in_proj_a_q4s, s->xb_q8, s->xb_q8s, dn, dim);
370388
else if (layer->delta_in_proj_a_q8)
371389
tq_matmul_q8(s->delta_ab, s->xb, layer->delta_in_proj_a_q8, layer->delta_in_proj_a_q8s, dn, dim);
372390
else
373391
tq_matmul(s->delta_ab, s->xb, layer->delta_in_proj_a, dn, dim);
374392

375393
/* beta = sigmoid(in_proj_b @ x) -> [dn] */
376394
if (layer->delta_in_proj_b_q4)
377-
tq_matmul_q4_preq(s->delta_ab + dn, layer->delta_in_proj_b_q4, layer->delta_in_proj_b_q4s, xb_q8, xb_q8s, dn, dim);
395+
tq_matmul_q4_preq(s->delta_ab + dn, layer->delta_in_proj_b_q4, layer->delta_in_proj_b_q4s, s->xb_q8, s->xb_q8s, dn, dim);
378396
else if (layer->delta_in_proj_b_q8)
379397
tq_matmul_q8(s->delta_ab + dn, s->xb, layer->delta_in_proj_b_q8, layer->delta_in_proj_b_q8s, dn, dim);
380398
else
@@ -387,8 +405,8 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
387405
* gate = softplus(alpha + dt_bias) * (-exp(A_log))
388406
* exp(gate) is the per-step multiplicative decay (< 1).
389407
* We precompute both gate_vals and exp(gate) to avoid repeated exp calls. */
390-
float gate_vals[128]; /* max 128 heads, stack-allocated for speed */
391-
float decay_vals[128]; /* precomputed exp(gate) per head */
408+
float* gate_vals = s->gate_vals;
409+
float* decay_vals = s->decay_vals;
392410
for (int h = 0; h < dn; h++) {
393411
float alpha_biased = s->delta_ab[h] + layer->delta_dt_bias[h];
394412
/* softplus: log(1 + exp(x)). For large x, softplus(x) ~ x */
@@ -445,7 +463,7 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
445463
/* NEON-optimized: fused decay + sk computation.
446464
* For each row i of state: decay state, accumulate sk.
447465
* sk[j] = sum_i(S[i,j] * K[i]) after decay */
448-
float sk[128] __attribute__((aligned(16)));
466+
float* sk = s->delta_sk;
449467
memset(sk, 0, (size_t)dv * sizeof(float));
450468

451469
float32x4_t vdecay = vdupq_n_f32(decay);
@@ -469,7 +487,7 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
469487
}
470488

471489
/* Delta: d = beta * (V - sk) */
472-
float d_vec[128] __attribute__((aligned(16)));
490+
float* d_vec = s->delta_dvec;
473491
float32x4_t vbeta = vdupq_n_f32(beta_h);
474492
{
475493
int j = 0;
@@ -518,7 +536,7 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
518536
}
519537

520538
/* Compute sk */
521-
float sk[128];
539+
float* sk = s->delta_sk;
522540
for (int j = 0; j < dv; j++) {
523541
float sum = 0.0f;
524542
for (int i = 0; i < dk; i++) {
@@ -528,7 +546,7 @@ static void deltanet_forward(tq_model_t* model, tq_state_t* s, int l) {
528546
}
529547

530548
/* Delta */
531-
float d_vec[128];
549+
float* d_vec = s->delta_dvec;
532550
for (int j = 0; j < dv; j++) {
533551
d_vec[j] = beta_h * (vh[j] - sk[j]);
534552
}
@@ -643,11 +661,9 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
643661

644662
/* Pre-quantize activation to Q8 once for all Q4 projections in this layer.
645663
* This eliminates redundant tq_quantize_row_q8 + malloc/free in each matmul_q4 call. */
646-
int8_t xb_q8[4096]; /* max hidden_dim */
647-
float xb_q8s[128 + 1];
648664
int has_q4 = (layer->wq_q4 != NULL);
649665
if (has_q4) {
650-
tq_quantize_row_q8(s->xb, xb_q8, xb_q8s, dim);
666+
tq_quantize_row_q8(s->xb, s->xb_q8, s->xb_q8s, dim);
651667
}
652668

653669
/* QKV projections.
@@ -656,12 +672,13 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
656672
float* gate_q = NULL;
657673
if (c->attn_output_gate) {
658674
int qg_dim = n_heads * head_dim * 2;
659-
if (layer->wq_q4)
660-
tq_matmul_q4_preq(s->xb2, layer->wq_q4, layer->wq_q4s, xb_q8, xb_q8s, qg_dim, dim);
661-
else if (layer->wq_q8)
675+
if (layer->wq_q4) {
676+
tq_matmul_q4_preq(s->xb2, layer->wq_q4, layer->wq_q4s, s->xb_q8, s->xb_q8s, qg_dim, dim);
677+
} else if (layer->wq_q8) {
662678
tq_matmul_q8(s->xb2, s->xb, layer->wq_q8, layer->wq_q8s, qg_dim, dim);
663-
else
679+
} else {
664680
tq_matmul(s->xb2, s->xb, layer->wq, qg_dim, dim);
681+
}
665682
/* Deinterleave: extract Q and gate from interleaved layout */
666683
gate_q = s->xb2;
667684
float* gate_tmp = s->att;
@@ -675,25 +692,28 @@ static void self_attn_forward(tq_model_t* model, tq_state_t* s, int l, int pos)
675692
}
676693
gate_q = gate_tmp;
677694
} else {
678-
if (layer->wq_q4)
679-
tq_matmul_q4_preq(s->q, layer->wq_q4, layer->wq_q4s, xb_q8, xb_q8s, n_heads * head_dim, dim);
680-
else if (layer->wq_q8)
695+
if (layer->wq_q4) {
696+
tq_matmul_q4_preq(s->q, layer->wq_q4, layer->wq_q4s, s->xb_q8, s->xb_q8s, n_heads * head_dim, dim);
697+
} else if (layer->wq_q8) {
681698
tq_matmul_q8(s->q, s->xb, layer->wq_q8, layer->wq_q8s, n_heads * head_dim, dim);
682-
else
699+
} else {
683700
tq_matmul(s->q, s->xb, layer->wq, n_heads * head_dim, dim);
701+
}
684702
}
685-
if (layer->wk_q4)
686-
tq_matmul_q4_preq(s->k, layer->wk_q4, layer->wk_q4s, xb_q8, xb_q8s, kv_dim, dim);
687-
else if (layer->wk_q8)
703+
if (layer->wk_q4) {
704+
tq_matmul_q4_preq(s->k, layer->wk_q4, layer->wk_q4s, s->xb_q8, s->xb_q8s, kv_dim, dim);
705+
} else if (layer->wk_q8) {
688706
tq_matmul_q8(s->k, s->xb, layer->wk_q8, layer->wk_q8s, kv_dim, dim);
689-
else
707+
} else {
690708
tq_matmul(s->k, s->xb, layer->wk, kv_dim, dim);
691-
if (layer->wv_q4)
692-
tq_matmul_q4_preq(s->v, layer->wv_q4, layer->wv_q4s, xb_q8, xb_q8s, kv_dim, dim);
693-
else if (layer->wv_q8)
709+
}
710+
if (layer->wv_q4) {
711+
tq_matmul_q4_preq(s->v, layer->wv_q4, layer->wv_q4s, s->xb_q8, s->xb_q8s, kv_dim, dim);
712+
} else if (layer->wv_q8) {
694713
tq_matmul_q8(s->v, s->xb, layer->wv_q8, layer->wv_q8s, kv_dim, dim);
695-
else
714+
} else {
696715
tq_matmul(s->v, s->xb, layer->wv, kv_dim, dim);
716+
}
697717

698718
/* Apply QK-norm if present (per-head RMSNorm) */
699719
if (layer->q_norm) {
@@ -1024,14 +1044,12 @@ float* tq_forward(tq_model_t* model, tq_state_t* s, int token, int pos) {
10241044

10251045
/* Pre-quantize xb for gate+up Q4 projections (same input, 2 matmuls) */
10261046
if (layer->w_gate_q4) {
1027-
int8_t ffn_xb_q8[4096];
1028-
float ffn_xb_q8s[128 + 1];
1029-
tq_quantize_row_q8(s->xb, ffn_xb_q8, ffn_xb_q8s, dim);
1047+
tq_quantize_row_q8(s->xb, s->xb_q8, s->xb_q8s, dim);
10301048

10311049
tq_matmul_q4_preq(s->hb, layer->w_gate_q4, layer->w_gate_q4s,
1032-
ffn_xb_q8, ffn_xb_q8s, c->intermediate_dim, dim);
1050+
s->xb_q8, s->xb_q8s, c->intermediate_dim, dim);
10331051
tq_matmul_q4_preq(s->hb2, layer->w_up_q4, layer->w_up_q4s,
1034-
ffn_xb_q8, ffn_xb_q8s, c->intermediate_dim, dim);
1052+
s->xb_q8, s->xb_q8s, c->intermediate_dim, dim);
10351053
} else {
10361054
if (layer->w_gate_q8) {
10371055
tq_matmul_q8(s->hb, s->xb, layer->w_gate_q8, layer->w_gate_q8s, c->intermediate_dim, dim);

0 commit comments

Comments (0)