Commit 9eed2dc

unamedkr and claude committed
feat(loader): MLA metadata capture for deepseek2 (Phase 2.1)
Add tq_model_config_t fields for MLA: is_mla, kv_lora_rank, qk_rope_head_dim, qk_nope_head_dim, v_head_dim. The loader detects arch=deepseek2 with attn_kv_a_mqa + attn_kv_b tensors and reads the GGUF metadata keys (attention.kv_lora_rank, attention.key_length, attention.value_length, rope.dimension_count) to populate them.

Logs the architectural KV compression at load time:

    MLA — kv_lora_rank=512, key_length=192 (rope=64 + nope=128), v_head_dim=128
    (KV cache compression 5120→576 = 8.9x vs standard)

That stacks with our turbo_kv_4b 8x for ~71x total compression — the moat for 256K context on 16 GB once Phase 2.2+ lands the forward-pass MLA decompression.

The forward pass still emits the loud Phase 1 warning. Phase 2.1 is strictly metadata; weight pointers and attention compute are TBD.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
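A quick sanity check of that arithmetic, as a standalone sketch. n_heads = 16 is an assumption inferred from the 5120 figure (it matches DeepSeek-V2-Lite); the loader itself reads n_heads from GGUF metadata.

    #include <stdio.h>

    /* Sanity-check sketch of the compression arithmetic in the message
     * above. n_heads = 16 is an assumption (DeepSeek-V2-Lite); the
     * loader reads the real value from the GGUF file. */
    int main(void) {
        int n_heads = 16, key_length = 192, v_head_dim = 128;
        int kv_lora_rank = 512, qk_rope_head_dim = 64;

        int standard = n_heads * (key_length + v_head_dim);  /* 5120 dims/token */
        int mla      = kv_lora_rank + qk_rope_head_dim;      /*  576 dims/token */

        printf("arch %.1fx, with turbo_kv_4b 8x: ~%.0fx total\n",
               (double)standard / mla,                       /* 8.9x */
               (double)standard / mla * 8.0);                /* ~71x */
        return 0;
    }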
1 parent 1c85bdc commit 9eed2dc

2 files changed

Lines changed: 49 additions & 0 deletions

include/turboquant/tq_engine.h

Lines changed: 19 additions & 0 deletions
@@ -74,6 +74,25 @@ typedef struct {
     /* Phi-3 fused-tensor flags — drive state buffer sizing */
     int has_fused_qkv;     /* any layer has gguf_w_qkv */
     int has_fused_up_gate; /* any layer has gguf_w_up_gate */
+
+    /* MLA (Multi-head Latent Attention) — DeepSeek V2/V3, Coder-V2.
+     * 0 = standard attention; 1 = MLA. When set, the standard wq/wk/wv
+     * pointers are NULL; instead each layer uses gguf_w_q,
+     * gguf_w_kv_a_mqa, gguf_w_kv_b plus an attn_kv_a_norm vector.
+     *
+     * Q has its own RoPE/no-RoPE split: head_dim = qk_nope_head_dim +
+     * qk_rope_head_dim. V uses v_head_dim (typically 128). The KV cache
+     * stores only the latent (kv_lora_rank dims) plus a single shared
+     * RoPE-K of qk_rope_head_dim — total per-token KV is
+     * (kv_lora_rank + qk_rope_head_dim) instead of
+     * (n_heads * (key_dim + v_dim)). For DeepSeek-V2-Lite that is
+     * 576 vs 5120 dims, an 8.9× architectural compression that
+     * stacks with our turbo_kv_4b 8× for ~71× total. */
+    int is_mla;
+    int kv_lora_rank;      /* latent dim, e.g., 512 */
+    int qk_rope_head_dim;  /* per-head RoPE dim, e.g., 64 */
+    int qk_nope_head_dim;  /* per-head no-RoPE dim, e.g., 128 */
+    int v_head_dim;        /* per-head value dim, e.g., 128 */
 } tq_model_config_t;
 
 /* ============================================================
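For a sense of how downstream buffer sizing might consume these fields, here is a hypothetical helper (not part of this commit). The non-MLA branch assumes the standard path caches one head_dim-wide K and one V per head:

    /* Hypothetical sketch, not in this commit: per-token KV-cache width
     * implied by tq_model_config_t. */
    static size_t tq_kv_dims_per_token(const tq_model_config_t* c) {
        if (c->is_mla) {
            /* MLA: one shared latent plus one shared RoPE-K slice. */
            return (size_t)(c->kv_lora_rank + c->qk_rope_head_dim); /* 512+64 = 576 */
        }
        /* Standard attention: full K and V for every head. */
        return (size_t)c->n_heads * 2u * (size_t)c->head_dim;
    }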

src/engine/tq_model.c

Lines changed: 30 additions & 0 deletions
@@ -3047,6 +3047,36 @@ tq_model_t* tq_load_gguf(const char* path) {
         }
     }
 
+    /* MLA detection (deepseek2): no attn_k, but attn_kv_a_mqa + attn_kv_b
+     * exist. Capture kv_lora_rank, qk_rope/nope_head_dim, v_head_dim from
+     * GGUF metadata so the future Phase 2 forward pass can size buffers
+     * and dispatch correctly. The forward path itself is still TBD —
+     * loader-level recording only. */
+    if (strcmp(gguf->arch, "deepseek2") == 0) {
+        const tq_gguf_tensor_t* kv_a = tq_gguf_find_tensor(gguf, "blk.0.attn_kv_a_mqa.weight");
+        const tq_gguf_tensor_t* kv_b = tq_gguf_find_tensor(gguf, "blk.0.attn_kv_b.weight");
+        if (kv_a && kv_b) {
+            c->is_mla           = 1;
+            c->kv_lora_rank     = tq_gguf_get_i32(gguf, GGUF_KEY("attention.kv_lora_rank"), 512);
+            int key_length      = tq_gguf_get_i32(gguf, GGUF_KEY("attention.key_length"), 192);
+            c->v_head_dim       = tq_gguf_get_i32(gguf, GGUF_KEY("attention.value_length"), 128);
+            c->qk_rope_head_dim = tq_gguf_get_i32(gguf, GGUF_KEY("rope.dimension_count"), 64);
+            c->qk_nope_head_dim = key_length - c->qk_rope_head_dim;
+            /* Override head_dim to the MLA key length (used for Q proj sizing) */
+            c->head_dim = key_length;
+            fprintf(stderr,
+                    "tq_load_gguf: MLA — kv_lora_rank=%d, key_length=%d "
+                    "(rope=%d + nope=%d), v_head_dim=%d "
+                    "(KV cache compression %d→%d = %.1fx vs standard)\n",
+                    c->kv_lora_rank, key_length,
+                    c->qk_rope_head_dim, c->qk_nope_head_dim, c->v_head_dim,
+                    c->n_heads * (key_length + c->v_head_dim),
+                    c->kv_lora_rank + c->qk_rope_head_dim,
+                    (double)(c->n_heads * (key_length + c->v_head_dim)) /
+                    (double)(c->kv_lora_rank + c->qk_rope_head_dim));
+        }
+    }
+
     /* MoE configuration */
     c->num_experts        = tq_gguf_get_i32(gguf, GGUF_KEY("expert_count"), 0);
     c->num_active_experts = tq_gguf_get_i32(gguf, GGUF_KEY("expert_used_count"), 0);
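A usage sketch of the new metadata after load. The model path and the &m->config accessor are illustrative assumptions; the diff only shows tq_load_gguf() and the config struct itself:

    #include <stdio.h>
    #include "turboquant/tq_engine.h"

    int main(void) {
        /* Path and the config accessor are assumptions for illustration. */
        tq_model_t* m = tq_load_gguf("deepseek-v2-lite.gguf");
        if (!m) return 1;

        const tq_model_config_t* c = &m->config;  /* assumed member */
        if (c->is_mla) {
            /* Phase 2.1: metadata only; the forward pass still warns. */
            printf("MLA: kv_lora_rank=%d rope=%d nope=%d v_head_dim=%d\n",
                   c->kv_lora_rank, c->qk_rope_head_dim,
                   c->qk_nope_head_dim, c->v_head_dim);
        }
        return 0;
    }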
