Skip to content

Commit 3c4e521

Browse files
committed
add assert that draft + shared kv should be on same device
1 parent c93c750 commit 3c4e521

1 file changed

Lines changed: 50 additions & 0 deletions

File tree

src/llama-context.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,55 @@ struct src_mctx_reset_on_exit {
4343
llama_memory_context_ptr * slot;
4444
~src_mctx_reset_on_exit() { if (slot) slot->reset(); }
4545
};
46+
47+
static void llama_assert_gemma4_mtp_source_placement(
48+
const llama_context * ctx,
49+
const llama_context * src) {
50+
if (!ctx || !src) {
51+
return;
52+
}
53+
54+
const auto & model_dft = ctx->get_model();
55+
const auto & model_tgt = src->get_model();
56+
57+
if (model_dft.arch != LLM_ARCH_GEMMA4_ASSISTANT || model_tgt.arch != LLM_ARCH_GEMMA4) {
58+
return;
59+
}
60+
61+
if (model_tgt.split_mode() == LLAMA_SPLIT_MODE_TENSOR) {
62+
return;
63+
}
64+
65+
const auto & hparams_dft = model_dft.hparams;
66+
const auto & hparams_tgt = model_tgt.hparams;
67+
68+
const int32_t il_tgt_full = (int32_t) hparams_tgt.n_layer - 1;
69+
const int32_t il_tgt_swa = (int32_t) hparams_tgt.n_layer - 2;
70+
71+
ggml_backend_dev_t dev_cpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
72+
if (!dev_cpu) {
73+
throw std::runtime_error("Gemma 4 assistant MTP placement check failed: no CPU backend found");
74+
}
75+
76+
const bool kv_offload = src->get_cparams().offload_kqv;
77+
78+
for (uint32_t il_dft = 0; il_dft < hparams_dft.n_layer; ++il_dft) {
79+
const int32_t il_tgt = hparams_dft.is_swa(il_dft) ? il_tgt_swa : il_tgt_full;
80+
81+
ggml_backend_dev_t dev_dft = model_dft.dev_layer(il_dft);
82+
ggml_backend_dev_t dev_kv = kv_offload ? model_tgt.dev_layer(il_tgt) : dev_cpu;
83+
84+
if (dev_dft != dev_kv) {
85+
throw std::runtime_error(format(
86+
"Gemma 4 assistant MTP placement mismatch: draft layer %d is on %s, "
87+
"but shared target KV layer %d is on %s",
88+
(int) il_dft,
89+
ggml_backend_dev_name(dev_dft),
90+
(int) il_tgt,
91+
ggml_backend_dev_name(dev_kv)));
92+
}
93+
}
94+
}
4695
}
4796

4897
llama_context::llama_context(
@@ -1144,6 +1193,7 @@ void llama_context::set_mtp_source(llama_context * src) {
11441193
if (src_ctx == src) {
11451194
return;
11461195
}
1196+
llama_assert_gemma4_mtp_source_placement(this, src);
11471197
src_ctx = src;
11481198
src_mctx_for_decode.reset();
11491199
// worst-case compute buffers were reserved without knowing about the source

0 commit comments

Comments
 (0)