Skip to content

Commit 00e8d49

Browse files
committed
Merge origin/feature/turboquant-kv-cache into b1-mtp-qwen-rebase
Brings in Gemma 4 + TurboQuant KV cache fixes:
- fix/turbo-rope-shift-gemma4 (PR #10)
- fix/iswa-get-can-shift-gemma4 (PR #9)
- fix/mtp-assistant-tensor-prefix (PR #7)
2 parents d8e7cda + b1a7d71 commit 00e8d49

3 files changed

Lines changed: 16 additions & 3 deletions

File tree

src/llama-kv-cache-iswa.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -233,8 +233,7 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_mtp(llama_seq_id seq_id, llam
233233

234234
bool llama_kv_cache_iswa::get_can_shift() const {
235235
return kv_base->get_can_shift() &&
236-
kv_swa->get_can_shift() &&
237-
kv_base->get_size() == kv_swa->get_size();
236+
kv_swa->get_can_shift();
238237
}
239238

240239
void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {

src/llama-kv-cache.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1315,6 +1315,8 @@ uint32_t llama_kv_cache::get_n_stream() const {
13151315
}
13161316

13171317
bool llama_kv_cache::get_has_shift() const {
1318+
// TurboQuant uses kernel-level WHT rotation -- position shift is a no-op
1319+
if (!layers.empty() && (layers[0].k->type == GGML_TYPE_TURBO2_0 || layers[0].k->type == GGML_TYPE_TURBO3_0 || layers[0].k->type == GGML_TYPE_TURBO4_0)) { return false; }
13181320
bool result = false;
13191321

13201322
for (uint32_t s = 0; s < n_stream; ++s) {
@@ -2070,6 +2072,8 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
20702072

20712073
for (const auto & layer : layers) {
20722074
const uint32_t il = layer.il;
2075+
const bool is_turbo_k = (layer.k->type == GGML_TYPE_TURBO2_0 || layer.k->type == GGML_TYPE_TURBO3_0 || layer.k->type == GGML_TYPE_TURBO4_0);
2076+
if (is_turbo_k) { continue; }
20732077

20742078
const int64_t n_head_kv = hparams.n_head_kv(il);
20752079
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

src/llama.cpp

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1226,7 +1226,17 @@ int llama_model_load_mtp_from_file(struct llama_model * model, const char * path
12261226
llama_model_free(aux);
12271227
return -7;
12281228
}
1229-
1229+
// Rename all MTP assistant tensors with "mtp." prefix so they can be
1230+
// uniquely targeted by -ot rules without colliding with the main model's
1231+
// tensors. Tensors already prefixed with "mtp." (pre_projection,
1232+
// post_projection, centroids, token_ordering) are left unchanged.
1233+
for (auto & kv : aux->tensors_by_name) {
1234+
if (kv.first.substr(0, 4) != "mtp.") {
1235+
std::string new_name = "mtp." + kv.first;
1236+
ggml_set_name(kv.second, new_name.c_str());
1237+
kv.first = new_name;
1238+
}
1239+
}
12301240
tgt->mtp_assistant.reset(aux);
12311241
return 0;
12321242
}

0 commit comments

Comments
 (0)