Skip to content

Commit 21b92a3

Browse files
unamedkr and claude committed
Add Gemma 3 architecture support (WIP: forward pass debugging)
Multi-architecture engine:
- Detect Gemma3 by pre_feedforward_layernorm tensor presence
- GeGLU activation (gelu_pytorch_tanh) alongside SwiGLU
- 4 RMSNorm per block (input, post_attn, pre_ffn, post_ffn)
- Dual RoPE theta (10K local, 1M global)
- Sliding window attention mask (512 tokens)
- QK-norm with query_pre_attn_scalar
- Embedding scaling by sqrt(hidden_dim)
- Tokenizer: support Gemma-style array-pair merges format
- Auto-detect gemma-3-270m-it in HF cache

Model loads, tokenizer works (514K merges), 131 tok/s achieved. Output quality needs debugging — forward pass verification in progress. Qwen3.5 path unchanged, all 20 tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 732cb6a commit 21b92a3

8 files changed

Lines changed: 425 additions & 45 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ spec/test_vectors/*.bin
3939
refs/.venv/
4040
.cache/
4141
*.tqm
42-
.venv/
42+
.venv/
43+
refs/

include/turboquant/tq_engine.h

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ typedef struct {
3434
/* QK-norm for self_attn (Qwen3.5 style) */
3535
int use_qk_norm; /* 1 if q_norm/k_norm weights present */
3636
int attn_output_gate; /* 1 if q_proj includes output gate (doubled q_proj output) */
37+
38+
/* Multi-architecture support */
39+
int model_type; /* 0=qwen35, 1=gemma3 */
40+
int sliding_window; /* sliding window size (512 for gemma3, 0 for unlimited) */
41+
float rope_local_base_freq; /* RoPE base freq for local/sliding layers (10000.0 for gemma3) */
42+
int n_norms_per_block; /* 2 for qwen35, 4 for gemma3 */
43+
float query_pre_attn_scalar; /* attention scaling: 1/sqrt(this) instead of 1/sqrt(head_dim), 0=use head_dim */
3744
} tq_model_config_t;
3845

3946
/* ============================================================
@@ -52,6 +59,11 @@ typedef struct {
5259
float* q_norm; /* [head_dim] QK-norm for queries */
5360
float* k_norm; /* [head_dim] QK-norm for keys */
5461

62+
/* Gemma3 extra norms (NULL for Qwen3.5) */
63+
float* post_attn_norm; /* [hidden_dim] post_attention_layernorm (Gemma3 only) */
64+
float* pre_ffn_norm; /* [hidden_dim] pre_feedforward_layernorm (Gemma3 only) */
65+
float* post_ffn_norm; /* [hidden_dim] post_feedforward_layernorm (Gemma3 only) */
66+
5567
/* Gated FFN weights (SwiGLU for Qwen3.5, GeGLU for Gemma3; present on ALL layers) */
5668
float* w_gate; /* [intermediate_dim, hidden_dim] */
5769
float* w_up; /* [intermediate_dim, hidden_dim] */
@@ -128,6 +140,9 @@ typedef struct {
128140
int n_attn_layers; /* number of layers with standard self_attn */
129141
int* attn_layer_indices; /* which layer indices have self_attn [n_attn_layers] */
130142

143+
/* Gemma3 sliding window support */
144+
int* layer_is_sliding; /* [n_layers] per-layer flag: 1=sliding, 0=global (NULL if not used) */
145+
131146
/* Q4 output weight (lm_head) — runtime quantized for fast logit projection */
132147
uint8_t* output_qs; /* [vocab_size * n_blocks * 16] Q4 packed nibbles */
133148
float* output_scales; /* [vocab_size * n_blocks] Q4 block scales */
@@ -278,9 +293,16 @@ typedef struct {
278293
int32_t n_attn_layers;
279294
int32_t attn_layer_indices[64]; /* which layers are self_attn (max 64) */
280295

296+
/* Multi-architecture support (Gemma3) */
297+
int32_t model_type; /* 0=qwen35, 1=gemma3 */
298+
int32_t sliding_window; /* sliding window size (512 for gemma3, 0=unlimited) */
299+
float rope_local_base_freq; /* RoPE base for local/sliding layers */
300+
int32_t n_norms_per_block;/* 2 for qwen35, 4 for gemma3 */
301+
float query_pre_attn_scalar; /* attention scaling (0=use head_dim) */
302+
281303
/* Padding to 512 bytes.
282-
* With pack(1): 8+32+8+16+12+8+32+260 = 376 used, 136 pad */
283-
uint8_t _pad[136];
304+
* With pack(1): 376 + 20 = 396 used, 116 pad */
305+
uint8_t _pad[116];
284306
} tqm_header_t;
285307
#pragma pack(pop)
286308

@@ -338,6 +360,7 @@ void tq_rmsnorm(float* out, const float* x, const float* weight, int n, float ep
338360
void tq_rope(float* q, float* k, int pos, int head_dim,
339361
int n_heads, int n_kv_heads, float freq_base);
340362
void tq_silu(float* x, int n);
363+
void tq_gelu_tanh(float* x, int n);
341364
void tq_softmax(float* x, int n);
342365
void tq_add(float* out, const float* a, const float* b, int n);
343366
void tq_mul(float* out, const float* a, const float* b, int n);

scripts/release_notes_v0.1.0.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
## TurboQuant.cpp v0.1.0 — First Release
2+
3+
Pure C LLM inference engine with KV cache compression. Matches llama.cpp single-thread speed.
4+
5+
### Highlights
6+
7+
- **82 tok/s peak** on Qwen3.5-0.8B (Q4, CPU-only, Apple Silicon)
8+
- **51 tok/s single-thread** — on par with llama.cpp (50.7 tok/s)
9+
- **7.5x KV cache compression** with 0.999 cosine similarity
10+
- **8 quantization types**: Uniform, Mixed, PolarQuant, QJL, TurboQuant
11+
- **TQM format**: pre-quantized binary model, mmap instant load (0.3s)
12+
- **Zero dependencies**: libc only, ~1MB binary
13+
- **One-command quickstart**: `bash scripts/quickstart.sh`
14+
15+
### What's Included
16+
17+
- Complete inference engine: DeltaNet + Self-Attention hybrid (Qwen3.5)
18+
- BPE tokenizer (248K vocab, embedded in TQM)
19+
- Q4 weight quantization with NEON 2-row batching
20+
- Thread pool with zero-overhead dispatch
21+
- Integer Q4×Q8 attention (ARM vdotq_s32)
22+
- 19 test suites, 135 tests
23+
- Python bindings (ctypes)
24+
- llama.cpp / vLLM integration stubs
25+
26+
### Quick Start
27+
28+
```bash
29+
git clone https://github.com/quantumaikr/TurboQuant.cpp && cd TurboQuant.cpp
30+
bash scripts/quickstart.sh "What is deep learning?"
31+
```
32+
33+
### References
34+
35+
- [TurboQuant](https://arxiv.org/abs/2504.19874) (ICLR 2026)
36+
- [QJL](https://arxiv.org/abs/2406.03482) (AAAI 2025)
37+
- [PolarQuant](https://arxiv.org/abs/2502.02617) (AISTATS 2026)

src/engine/tq_model.c

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -880,17 +880,42 @@ static tq_model_t* tq_load_safetensors(const char* path) {
880880
model->config.intermediate_dim = model->config.hidden_dim * 4;
881881
}
882882

883+
/* Detect Gemma3 architecture by presence of pre_feedforward_layernorm */
884+
{
885+
snprintf(name_buf, sizeof(name_buf),
886+
"model.layers.0.pre_feedforward_layernorm.weight");
887+
tensor_info_t* gemma3_probe = find_tensor(tensors, n_tensors, name_buf);
888+
if (gemma3_probe) {
889+
model->config.model_type = 1; /* gemma3 */
890+
model->config.n_norms_per_block = 4;
891+
fprintf(stderr, "tq_load_model: detected Gemma3 architecture (4 norms per block)\n");
892+
} else {
893+
model->config.model_type = 0; /* qwen35 */
894+
model->config.n_norms_per_block = 2;
895+
}
896+
}
897+
883898
/* Defaults — chosen per architecture: Gemma3, Qwen3.5 (if DeltaNet detected), or generic fallback */
884899
model->config.max_seq_len = 4096;
885-
if (model->config.delta_n_heads > 0) {
900+
if (model->config.model_type == 1) {
901+
/* Gemma3: rope_theta=1M for global, 10K for local, rms_norm_eps=1e-6 */
902+
model->config.rope_freq_base = 1000000.0f; /* global layers */
903+
model->config.rope_local_base_freq = 10000.0f; /* sliding/local layers */
904+
model->config.rms_norm_eps = 1e-6f;
905+
model->config.partial_rotary_factor = 0.0f;
906+
model->config.sliding_window = 512;
907+
model->config.query_pre_attn_scalar = 256.0f;
908+
} else if (model->config.delta_n_heads > 0) {
886909
/* Qwen3.5 uses rope_theta=10M, rms_norm_eps=1e-6, partial_rotary=0.25 */
887910
model->config.rope_freq_base = 10000000.0f;
888911
model->config.rms_norm_eps = 1e-6f;
889912
model->config.partial_rotary_factor = 0.25f;
913+
model->config.query_pre_attn_scalar = 0.0f;
890914
} else {
891915
model->config.rope_freq_base = 10000.0f;
892916
model->config.rms_norm_eps = 1e-5f;
893917
model->config.partial_rotary_factor = 0.0f;
918+
model->config.query_pre_attn_scalar = 0.0f;
894919
}
895920

896921
/* Allocate layer weight pointers */
@@ -917,13 +942,32 @@ static tq_model_t* tq_load_safetensors(const char* path) {
917942
find_tensor(tensors, n_tensors, name_buf),
918943
&conv_buf, &conv_used, conv_capacity);
919944

920-
/* FFN norm */
945+
/* FFN norm (Qwen3.5: post_attention_layernorm used as pre-FFN norm) */
921946
snprintf(name_buf, sizeof(name_buf),
922947
"model.layers.%d.post_attention_layernorm.weight", l);
923948
layer->ffn_norm = load_tensor(data_base,
924949
find_tensor(tensors, n_tensors, name_buf),
925950
&conv_buf, &conv_used, conv_capacity);
926951

952+
/* Gemma3 extra norms: post_attn, pre_ffn, post_ffn */
953+
if (model->config.model_type == 1) {
954+
/* For Gemma3, post_attention_layernorm is applied to attn output,
955+
* not as pre-FFN norm. Store it in post_attn_norm. */
956+
layer->post_attn_norm = layer->ffn_norm;
957+
958+
snprintf(name_buf, sizeof(name_buf),
959+
"model.layers.%d.pre_feedforward_layernorm.weight", l);
960+
layer->pre_ffn_norm = load_tensor(data_base,
961+
find_tensor(tensors, n_tensors, name_buf),
962+
&conv_buf, &conv_used, conv_capacity);
963+
964+
snprintf(name_buf, sizeof(name_buf),
965+
"model.layers.%d.post_feedforward_layernorm.weight", l);
966+
layer->post_ffn_norm = load_tensor(data_base,
967+
find_tensor(tensors, n_tensors, name_buf),
968+
&conv_buf, &conv_used, conv_capacity);
969+
}
970+
927971
/* Q, K, V, O projections — only exist for self_attn layers */
928972
snprintf(name_buf, sizeof(name_buf),
929973
"model.layers.%d.self_attn.q_proj.weight", l);
@@ -1107,6 +1151,77 @@ static tq_model_t* tq_load_safetensors(const char* path) {
11071151
fprintf(stderr, "tq_load_model: applied Qwen3.5 RMSNorm +1 weight adjustment\n");
11081152
}
11091153

1154+
/* Gemma3 RMSNorm adjustment: same (1+w) scaling as Qwen3.5 */
1155+
if (model->config.model_type == 1) {
1156+
int dim_h = model->config.hidden_dim;
1157+
int head_dim_h = model->config.head_dim;
1158+
1159+
for (int l = 0; l < n_layers; l++) {
1160+
tq_layer_weights_t* layer_w = &model->layers[l];
1161+
if (layer_w->attn_norm) {
1162+
for (int i = 0; i < dim_h; i++) {
1163+
layer_w->attn_norm[i] += 1.0f;
1164+
}
1165+
}
1166+
if (layer_w->post_attn_norm) {
1167+
for (int i = 0; i < dim_h; i++) {
1168+
layer_w->post_attn_norm[i] += 1.0f;
1169+
}
1170+
}
1171+
if (layer_w->pre_ffn_norm) {
1172+
for (int i = 0; i < dim_h; i++) {
1173+
layer_w->pre_ffn_norm[i] += 1.0f;
1174+
}
1175+
}
1176+
if (layer_w->post_ffn_norm) {
1177+
for (int i = 0; i < dim_h; i++) {
1178+
layer_w->post_ffn_norm[i] += 1.0f;
1179+
}
1180+
}
1181+
if (layer_w->q_norm) {
1182+
for (int i = 0; i < head_dim_h; i++) {
1183+
layer_w->q_norm[i] += 1.0f;
1184+
}
1185+
}
1186+
if (layer_w->k_norm) {
1187+
for (int i = 0; i < head_dim_h; i++) {
1188+
layer_w->k_norm[i] += 1.0f;
1189+
}
1190+
}
1191+
}
1192+
if (model->output_norm) {
1193+
for (int i = 0; i < dim_h; i++) {
1194+
model->output_norm[i] += 1.0f;
1195+
}
1196+
}
1197+
fprintf(stderr, "tq_load_model: applied Gemma3 RMSNorm +1 weight adjustment\n");
1198+
1199+
/* Set up layer_is_sliding for Gemma3.
1200+
* Pattern: 5 sliding + 1 full, repeated. Layers 0-4=sliding, 5=full, etc.
1201+
* Derived from the layer index: (l + 1) % 6 == 0 marks a global layer. */
1202+
model->layer_is_sliding = (int*)calloc((size_t)n_layers, sizeof(int));
1203+
if (model->layer_is_sliding) {
1204+
for (int l = 0; l < n_layers; l++) {
1205+
/* Full/global attention every 6th layer (indices 5, 11, 17, ...) */
1206+
if ((l + 1) % 6 == 0) {
1207+
model->layer_is_sliding[l] = 0; /* global */
1208+
} else {
1209+
model->layer_is_sliding[l] = 1; /* sliding */
1210+
}
1211+
}
1212+
int n_sliding = 0, n_global = 0;
1213+
for (int l = 0; l < n_layers; l++) {
1214+
if (model->layer_is_sliding[l]) {
1215+
n_sliding++;
1216+
} else {
1217+
n_global++;
1218+
}
1219+
}
1220+
fprintf(stderr, "tq_load_model: Gemma3 layer types: %d sliding, %d global\n",
1221+
n_sliding, n_global);
1222+
}
1223+
}
1224+
11101225
fprintf(stderr, "tq_load_model: loaded %d layers (%d with self_attn), "
11111226
"dim=%d, heads=%d/%d, vocab=%d\n",
11121227
model->config.n_layers, model->n_attn_layers,
@@ -1679,6 +1794,13 @@ tq_model_t* tq_load_tqm(const char* path) {
16791794
c->use_qk_norm = hdr->use_qk_norm;
16801795
c->attn_output_gate = hdr->attn_output_gate;
16811796

1797+
/* Multi-architecture fields */
1798+
c->model_type = hdr->model_type;
1799+
c->sliding_window = hdr->sliding_window;
1800+
c->rope_local_base_freq = hdr->rope_local_base_freq;
1801+
c->n_norms_per_block = hdr->n_norms_per_block;
1802+
c->query_pre_attn_scalar = hdr->query_pre_attn_scalar;
1803+
16821804
/* Attn layer indices */
16831805
model->n_attn_layers = hdr->n_attn_layers;
16841806
if (hdr->n_attn_layers > 0) {
@@ -1748,6 +1870,13 @@ tq_model_t* tq_load_tqm(const char* path) {
17481870
TQM_READ_FP32(layer->attn_norm, dim);
17491871
TQM_READ_FP32(layer->ffn_norm, dim);
17501872

1873+
/* Gemma3 extra norms */
1874+
if (c->model_type == 1) {
1875+
layer->post_attn_norm = layer->ffn_norm; /* shares storage */
1876+
TQM_READ_FP32(layer->pre_ffn_norm, dim);
1877+
TQM_READ_FP32(layer->post_ffn_norm, dim);
1878+
}
1879+
17511880
if (is_attn_layer && is_attn_layer[l]) {
17521881
/* Self-attention layer */
17531882
TQM_READ_Q4(layer->wq_q4, layer->wq_q4s, qg_dim, dim);
@@ -1814,6 +1943,20 @@ tq_model_t* tq_load_tqm(const char* path) {
18141943
model->use_q4_weights = 1;
18151944
free(is_attn_layer);
18161945

1946+
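/* Set up Gemma3 layer_is_sliding for a TQM-loaded model (rebuilt here from the fixed 5 sliding + 1 global index pattern, not read from the file) */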
1947+
if (c->model_type == 1 && c->sliding_window > 0) {
1948+
model->layer_is_sliding = (int*)calloc((size_t)c->n_layers, sizeof(int));
1949+
if (model->layer_is_sliding) {
1950+
for (int l = 0; l < c->n_layers; l++) {
1951+
if ((l + 1) % 6 == 0) {
1952+
model->layer_is_sliding[l] = 0; /* global */
1953+
} else {
1954+
model->layer_is_sliding[l] = 1; /* sliding */
1955+
}
1956+
}
1957+
}
1958+
}
1959+
18171960
/* Runtime Q4 quantization of lm_head (output projection) for fast logit computation.
18181961
* BF16 matmul on 248K x 1024 is slow; Q4 matmul is ~4x faster. */
18191962
if (model->output_weight_bf16) {
@@ -1982,6 +2125,12 @@ int tq_save_tqm(tq_model_t* model, const char* tokenizer_path,
19822125
hdr.use_qk_norm = c->use_qk_norm;
19832126
hdr.attn_output_gate = c->attn_output_gate;
19842127

2128+
hdr.model_type = c->model_type;
2129+
hdr.sliding_window = c->sliding_window;
2130+
hdr.rope_local_base_freq = c->rope_local_base_freq;
2131+
hdr.n_norms_per_block = c->n_norms_per_block;
2132+
hdr.query_pre_attn_scalar = c->query_pre_attn_scalar;
2133+
19852134
hdr.weight_quant = 4; /* Q4 */
19862135
hdr.embed_format = 16; /* BF16 */
19872136

@@ -2041,6 +2190,12 @@ int tq_save_tqm(tq_model_t* model, const char* tokenizer_path,
20412190
TQM_WRITE_FP32(layer->attn_norm, dim);
20422191
TQM_WRITE_FP32(layer->ffn_norm, dim);
20432192

2193+
/* Gemma3 extra norms */
2194+
if (c->model_type == 1) {
2195+
TQM_WRITE_FP32(layer->pre_ffn_norm, dim);
2196+
TQM_WRITE_FP32(layer->post_ffn_norm, dim);
2197+
}
2198+
20442199
if (is_attn_layer[l]) {
20452200
TQM_WRITE_Q4(layer->wq_q4, layer->wq_q4s, qg_dim, dim);
20462201
TQM_WRITE_Q4(layer->wk_q4, layer->wk_q4s, kv_dim, dim);
@@ -2144,6 +2299,7 @@ void tq_free_model(tq_model_t* model) {
21442299
free(model->_q8_data);
21452300
free(model->_q4_data);
21462301
free(model->attn_layer_indices);
2302+
free(model->layer_is_sliding);
21472303
free(model->layers);
21482304
free(model);
21492305
}

src/engine/tq_ops.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1060,6 +1060,19 @@ void tq_silu(float* x, int n) {
10601060
#endif
10611061
}
10621062

1063+
/* ============================================================
1064+
* GELU with tanh approximation (Gemma3 GeGLU activation)
1065+
* gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
1066+
* ============================================================ */
1067+
void tq_gelu_tanh(float* x, int n) { /* in-place: each x[i] for i in [0, n) is overwritten with its activation */
1068+
for (int i = 0; i < n; i++) {
1069+
float xi = x[i]; /* keep the input value; x[i] is overwritten below */
1070+
float x3 = xi * xi * xi; /* cubic term x^3 */
1071+
float inner = 0.7978845608f * (xi + 0.044715f * x3); /* tanh argument; 0.7978845608 is sqrt(2/pi) rounded */
1072+
x[i] = 0.5f * xi * (1.0f + tanhf(inner)); /* 0.5 * x * (1 + tanh(inner)) */
1073+
}
1074+
}
1075+
10631076
/* ============================================================
10641077
* Softmax: numerically stable with max subtraction
10651078
* ============================================================ */

0 commit comments

Comments
 (0)