@@ -416,6 +416,12 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
416416
417417 int32_t n_embd = 0 ;
418418
419+ // Stacked-MTP draft chain length. Each MTP block k is trained to take the
420+ // (k-1)-th block's hidden state and predict the (k+1)-th token ahead, so
421+ // we can draft at most `n_mtp_layers` tokens per round before going
422+ // out-of-distribution.
423+ int32_t n_mtp_layers = 0 ;
424+
419425 // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
420426 // The last h-row of one process() call needs the first token of the NEXT
421427 // call to pair with, so it's stashed here until that next call fires.
@@ -442,7 +448,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
442448 auto * ctx_dft = this ->params .ctx_dft ;
443449 GGML_ASSERT (ctx_tgt && ctx_dft && " MTP requires ctx_tgt and ctx_dft to be set" );
444450
445- n_embd = llama_model_n_embd (llama_get_model (ctx_dft));
451+ n_embd = llama_model_n_embd (llama_get_model (ctx_dft));
452+ n_mtp_layers = llama_model_n_nextn (llama_get_model (ctx_dft));
453+ GGML_ASSERT (n_mtp_layers > 0 && " MTP draft requires the draft model to declare nextn_predict_layers > 0" );
446454
447455 LOG_INF (" %s: adding speculative implementation 'draft-mtp'\n " , __func__);
448456 LOG_INF (" %s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d\n " , __func__, this ->params .n_max , this ->params .n_min , this ->params .p_min , n_embd);
@@ -635,9 +643,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
635643 std::memcpy (batch.embd + n_embd*(batch.n_tokens - 1 ), h_row, row_bytes);
636644 }
637645
638- // First draft step uses the first MTP block (step 0). Archs with a
639- // single MTP block ignore this; multi-block archs (Step-3.5-Flash) use
640- // it to round-robin across their N MTP layers.
646+ // Stacked-MTP is a *chain*: block k consumes block (k-1)'s hidden
647+ // state and predicts one token further than block (k-1). Step k uses
648+ // MTP block k; we cannot exceed `n_mtp_layers` steps without going
649+ // out of training distribution.
641650 llama_set_mtp_step (ctx_dft, 0 );
642651
643652 int ret = llama_decode (ctx_dft, batch);
@@ -648,6 +657,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
648657
649658 int i = 0 ;
650659
660+ // Cap draft depth: never run more MTP chain steps than there are
661+ // trained MTP blocks. `params.n_max` may be larger; we just stop.
662+ const int n_chain_max = n_mtp_layers;
663+
651664 while (n_drafting > 0 ) {
652665 int i_batch = 0 ;
653666
@@ -704,8 +717,16 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
704717 break ;
705718 }
706719
707- // Step i+1: feed the i-th sampled draft token into the (i+1)-th
708- // MTP block. Multi-block archs round-robin via mtp_step % N.
720+ // We just sampled the (i+1)-th draft token (T_{i+1}). If the
721+ // chain has no further block (T_{i+2} would need MTP block i+1
722+ // which doesn't exist), stop — sampling, not decoding, is what
723+ // matters for the drafted result.
724+ if (i + 1 >= n_chain_max) {
725+ break ;
726+ }
727+
728+ // Step i+1: feed the i-th sampled draft token into MTP block i+1.
729+ // Direct indexing, no modulo — the chain is bounded by n_mtp_layers.
709730 llama_set_mtp_step (ctx_dft, (uint32_t )(i + 1 ));
710731
711732 // evaluate the drafted tokens on the draft model
0 commit comments