@@ -43,6 +43,55 @@ struct src_mctx_reset_on_exit {
4343 llama_memory_context_ptr * slot;
4444 ~src_mctx_reset_on_exit () { if (slot) slot->reset (); }
4545};
46+
47+ static void llama_assert_gemma4_mtp_source_placement (
48+ const llama_context * ctx,
49+ const llama_context * src) {
50+ if (!ctx || !src) {
51+ return ;
52+ }
53+
54+ const auto & model_dft = ctx->get_model ();
55+ const auto & model_tgt = src->get_model ();
56+
57+ if (model_dft.arch != LLM_ARCH_GEMMA4_ASSISTANT || model_tgt.arch != LLM_ARCH_GEMMA4) {
58+ return ;
59+ }
60+
61+ if (model_tgt.split_mode () == LLAMA_SPLIT_MODE_TENSOR) {
62+ return ;
63+ }
64+
65+ const auto & hparams_dft = model_dft.hparams ;
66+ const auto & hparams_tgt = model_tgt.hparams ;
67+
68+ const int32_t il_tgt_full = (int32_t ) hparams_tgt.n_layer - 1 ;
69+ const int32_t il_tgt_swa = (int32_t ) hparams_tgt.n_layer - 2 ;
70+
71+ ggml_backend_dev_t dev_cpu = ggml_backend_dev_by_type (GGML_BACKEND_DEVICE_TYPE_CPU);
72+ if (!dev_cpu) {
73+ throw std::runtime_error (" Gemma 4 assistant MTP placement check failed: no CPU backend found" );
74+ }
75+
76+ const bool kv_offload = src->get_cparams ().offload_kqv ;
77+
78+ for (uint32_t il_dft = 0 ; il_dft < hparams_dft.n_layer ; ++il_dft) {
79+ const int32_t il_tgt = hparams_dft.is_swa (il_dft) ? il_tgt_swa : il_tgt_full;
80+
81+ ggml_backend_dev_t dev_dft = model_dft.dev_layer (il_dft);
82+ ggml_backend_dev_t dev_kv = kv_offload ? model_tgt.dev_layer (il_tgt) : dev_cpu;
83+
84+ if (dev_dft != dev_kv) {
85+ throw std::runtime_error (format (
86+ " Gemma 4 assistant MTP placement mismatch: draft layer %d is on %s, "
87+ " but shared target KV layer %d is on %s" ,
88+ (int ) il_dft,
89+ ggml_backend_dev_name (dev_dft),
90+ (int ) il_tgt,
91+ ggml_backend_dev_name (dev_kv)));
92+ }
93+ }
94+ }
4695}
4796
4897llama_context::llama_context (
@@ -1144,6 +1193,7 @@ void llama_context::set_mtp_source(llama_context * src) {
11441193 if (src_ctx == src) {
11451194 return ;
11461195 }
1196+ llama_assert_gemma4_mtp_source_placement (this , src);
11471197 src_ctx = src;
11481198 src_mctx_for_decode.reset ();
11491199 // worst-case compute buffers were reserved without knowing about the source
0 commit comments