@@ -113,7 +113,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
113113 layer.ffn_down_shexp = create_tensor (tn (LLM_TENSOR_FFN_DOWN_SHEXP , " weight" , i), {hparams.n_ff_shexp , n_embd}, TENSOR_NOT_REQUIRED );
114114 };
115115
116- auto load_block_mtp = [&](int i) {
116+ auto load_block_mtp = [&](int i, bool is_first_mtp ) {
117117 auto & layer = layers[i];
118118
119119 const uint32_t n_head_l = hparams.n_head (i);
@@ -123,7 +123,14 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
123123 // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the
124124 // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head).
125125 // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only.
126- layer.attn_norm = create_tensor (tn (LLM_TENSOR_ATTN_NORM , " weight" , i), {n_embd}, mtp_flags);
126+ //
127+ // Only the FIRST MTP block (i == n_main) is required for the
128+ // single-block MTP runtime; trailing MTP blocks are always tolerated
129+ // as missing so pruned GGUFs (block 0 only) load cleanly. Override
130+ // mtp_flags to NOT_REQUIRED for those.
131+ const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED );
132+
133+ layer.attn_norm = create_tensor (tn (LLM_TENSOR_ATTN_NORM , " weight" , i), {n_embd}, eff_mtp_flags);
127134 layer.attn_q_norm = create_tensor (tn (LLM_TENSOR_ATTN_Q_NORM , " weight" , i), {n_embd_head_k}, TENSOR_NOT_REQUIRED );
128135 layer.attn_k_norm = create_tensor (tn (LLM_TENSOR_ATTN_K_NORM , " weight" , i), {n_embd_head_k}, TENSOR_NOT_REQUIRED );
129136
@@ -134,12 +141,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
134141 layer.rope_freqs = create_tensor (tn (LLM_TENSOR_ROPE_FREQS , " weight" , i), {n_rot_max/2 }, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED );
135142 }
136143
137- create_tensor_qkv (layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags );
138- layer.wo = create_tensor (tn (LLM_TENSOR_ATTN_OUT , " weight" , i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags );
144+ create_tensor_qkv (layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags );
145+ layer.wo = create_tensor (tn (LLM_TENSOR_ATTN_OUT , " weight" , i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags );
139146
140147 layer.wqkv_gate = create_tensor (tn (LLM_TENSOR_ATTN_GATE , " weight" , i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED );
141148
142- layer.ffn_norm = create_tensor (tn (LLM_TENSOR_FFN_NORM , " weight" , i), {n_embd}, mtp_flags );
149+ layer.ffn_norm = create_tensor (tn (LLM_TENSOR_FFN_NORM , " weight" , i), {n_embd}, eff_mtp_flags );
143150
144151 // dense MLP (leading dense blocks) — present if the MTP block isn't MoE
145152 layer.ffn_gate = create_tensor (tn (LLM_TENSOR_FFN_GATE , " weight" , i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED );
@@ -159,9 +166,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
159166 layer.ffn_down_shexp = create_tensor (tn (LLM_TENSOR_FFN_DOWN_SHEXP , " weight" , i), {hparams.n_ff_shexp , n_embd}, TENSOR_NOT_REQUIRED );
160167
161168 // NextN-specific tensors that define the MTP block.
162- layer.nextn .eh_proj = create_tensor (tn (LLM_TENSOR_NEXTN_EH_PROJ , " weight" , i), { 2 * n_embd, n_embd }, mtp_flags );
163- layer.nextn .enorm = create_tensor (tn (LLM_TENSOR_NEXTN_ENORM , " weight" , i), { n_embd }, mtp_flags );
164- layer.nextn .hnorm = create_tensor (tn (LLM_TENSOR_NEXTN_HNORM , " weight" , i), { n_embd }, mtp_flags );
169+ layer.nextn .eh_proj = create_tensor (tn (LLM_TENSOR_NEXTN_EH_PROJ , " weight" , i), { 2 * n_embd, n_embd }, eff_mtp_flags );
170+ layer.nextn .enorm = create_tensor (tn (LLM_TENSOR_NEXTN_ENORM , " weight" , i), { n_embd }, eff_mtp_flags );
171+ layer.nextn .hnorm = create_tensor (tn (LLM_TENSOR_NEXTN_HNORM , " weight" , i), { n_embd }, eff_mtp_flags );
165172 layer.nextn .embed_tokens = create_tensor (tn (LLM_TENSOR_NEXTN_EMBED_TOKENS , " weight" , i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED );
166173 layer.nextn .shared_head_head = create_tensor (tn (LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD , " weight" , i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED );
167174 layer.nextn .shared_head_norm = create_tensor (tn (LLM_TENSOR_NEXTN_SHARED_HEAD_NORM , " weight" , i), { n_embd }, TENSOR_NOT_REQUIRED );
@@ -170,8 +177,13 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
170177 for (int i = 0 ; i < (int ) n_main; ++i) {
171178 load_block_trunk (i, trunk_flags);
172179 }
180+ // Only the first MTP block (i == n_main) is required at runtime — the
181+ // single-block-MTP graph in build_arch_graph always uses that one.
182+ // Trailing MTP blocks are loaded if present (so an un-pruned GGUF with
183+ // all MTP layers still works) but tolerated when absent via the pruning
184+ // path. See scripts/prune_step35_extra_mtp.py for the pruner.
173185 for (int i = (int ) n_main; i < n_layer; ++i) {
174- load_block_mtp (i);
186+ load_block_mtp (i, /* is_first_mtp= */ i == ( int ) n_main );
175187 }
176188}
177189
@@ -362,13 +374,13 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
362374 : llm_graph_context(params) {
363375 GGML_ASSERT (hparams.nextn_predict_layers > 0 && " STEP35 MTP requires nextn_predict_layers > 0" );
364376
365- // Round-robin across MTP blocks at draft step boundaries. Matches vllm's
366- // `current_step_idx = spec_step_idx % num_mtp_layers` (step3p5_mtp.py).
367- // The first MTP block lives at layer index `n_main` ; the speculative
368- // driver bumps `cparams.mtp_step` between AR iterations.
369- const int n_main = ( int ) hparams. n_layer - ( int ) hparams. nextn_predict_layers ;
370- const int step_offset = ( int ) (cparams. mtp_step % hparams. nextn_predict_layers );
371- const int il = n_main + step_offset ;
377+ // Single-block MTP only: always run the first trained MTP block (Qwen
378+ // MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to
379+ // be a much deeper refactor than this PR justifies ; the trailing MTP
380+ // blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just
381+ // block 0) also work — see load_arch_tensors below and
382+ // scripts/prune_step35_extra_mtp.py.
383+ const int il = ( int ) hparams. n_layer - ( int ) hparams. nextn_predict_layers ;
372384 const auto & layer = model.layers [il];
373385
374386 GGML_ASSERT (layer.nextn .eh_proj && " MTP block missing nextn.eh_proj" );
0 commit comments