@@ -22,6 +22,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
2222 COMMON_SPECULATIVE_TYPE_NONE,
2323 COMMON_SPECULATIVE_TYPE_DRAFT,
2424 COMMON_SPECULATIVE_TYPE_EAGLE3,
25+ COMMON_SPECULATIVE_TYPE_MTP,
2526 COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
2627 COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
2728 COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
@@ -33,6 +34,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
3334 {" none" , COMMON_SPECULATIVE_TYPE_NONE},
3435 {" draft" , COMMON_SPECULATIVE_TYPE_DRAFT},
3536 {" eagle3" , COMMON_SPECULATIVE_TYPE_EAGLE3},
37+ {" mtp" , COMMON_SPECULATIVE_TYPE_MTP},
3638 {" ngram_simple" , COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
3739 {" ngram_map_k" , COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
3840 {" ngram_map_k4v" , COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -599,6 +601,171 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
599601 }
600602};
601603
604+ struct common_speculative_state_mtp : public common_speculative_state {
605+ llama_context * ctx_tgt = nullptr ;
606+ llama_context * ctx_mtp = nullptr ;
607+
608+ llama_batch batch; // single token draft step
609+ common_sampler * smpl = nullptr ;
610+ int32_t n_embd = 0 ;
611+
612+ uint16_t last_n_drafted = 0 ;
613+ int32_t last_n_accepted = -1 ;
614+
615+ common_speculative_state_mtp (enum common_speculative_type type,
616+ llama_context * ctx_tgt,
617+ llama_context * ctx_mtp)
618+ : common_speculative_state(type), ctx_tgt(ctx_tgt), ctx_mtp(ctx_mtp) {
619+ GGML_ASSERT (ctx_tgt && ctx_mtp);
620+ const llama_model * model_mtp = llama_get_model (ctx_mtp);
621+ n_embd = llama_model_n_embd (model_mtp);
622+
623+ {
624+ common_params_sampling sparams;
625+ sparams.no_perf = false ;
626+ sparams.top_k = 1 ;
627+ sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
628+ smpl = common_sampler_init (model_mtp, sparams);
629+ }
630+
631+ // TODO: multiple seq support
632+ batch = llama_batch_init (/* n_tokens=*/ 1 , /* embd=*/ n_embd, /* n_seq_max=*/ 1 );
633+ batch.token = (llama_token *) malloc (sizeof (llama_token));
634+ batch.n_tokens = 1 ;
635+ batch.n_seq_id [0 ] = 1 ;
636+ batch.seq_id [0 ][0 ] = 0 ;
637+ batch.logits [0 ] = 1 ;
638+
639+ llama_set_mtp (ctx_tgt, ctx_mtp);
640+ }
641+
642+ ~common_speculative_state_mtp () override {
643+ llama_set_mtp (ctx_tgt, nullptr );
644+ llama_batch_free (batch);
645+ common_sampler_free (smpl);
646+ if (ctx_mtp) {
647+ llama_free (ctx_mtp);
648+ }
649+ }
650+
651+ void begin (const llama_tokens & prompt) override {
652+ last_n_accepted = -1 ;
653+ last_n_drafted = 0 ;
654+
655+ const int32_t N = (int32_t ) prompt.size ();
656+ if (N <= 0 ) {
657+ return ;
658+ }
659+ const llama_pos pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
660+ if (pos_max < N - 1 ) {
661+ LOG_WRN (" %s: ctx_mtp pos_max=%d < N-1=%d — "
662+ " streaming hook may not be registered or not all prefill rows "
663+ " have logits=true. Drafts may degrade.\n " ,
664+ __func__, (int ) pos_max, N - 1 );
665+ }
666+ }
667+
668+ void draft (
669+ const common_params_speculative & params,
670+ const llama_tokens & prompt_tgt,
671+ llama_token id_last,
672+ llama_tokens & draft_tokens) override {
673+ GGML_UNUSED (prompt_tgt);
674+ draft_tokens.clear ();
675+
676+ // accept with no-accepts (i.e. 0 accepts) returns early, but we still need to remove from the MTP kv-cache
677+ // TODO: check if bug in other spec states
678+ if (last_n_drafted > 0 ) {
679+ const int32_t n_to_drop = (int32_t ) last_n_drafted - 1 ;
680+ if (n_to_drop > 0 ) {
681+ const llama_pos pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
682+ if (pos_max >= 0 ) {
683+ const llama_pos drop_from = pos_max - n_to_drop + 1 ;
684+ llama_memory_seq_rm (llama_get_memory (ctx_mtp), 0 , drop_from, -1 );
685+ }
686+ }
687+ last_n_drafted = 0 ;
688+ last_n_accepted = 0 ;
689+ }
690+
691+ const int32_t n_max = std::max (1 , params.draft .n_max );
692+ const size_t row_bytes = (size_t ) n_embd * sizeof (float );
693+
694+ llama_token cond_tok = id_last;
695+ llama_pos pos = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 ) + 1 ;
696+
697+ // auto-regressive loop for MTP
698+ for (int32_t k = 0 ; k < n_max; ++k) {
699+ ggml_tensor * src;
700+ int32_t src_row;
701+ if (k == 0 ) {
702+ src = llama_context_get_t_h_pre_norm (ctx_tgt);
703+ if (last_n_accepted < 0 ) {
704+ // First draft after begin(): trunk's most recent decode is
705+ // the last prefill ubatch; its last row is h_{N-1}.
706+ src_row = (src && src->ne [1 ] > 0 ) ? (int32_t ) src->ne [1 ] - 1 : 0 ;
707+ } else {
708+ src_row = last_n_accepted;
709+ }
710+ llama_synchronize (ctx_tgt);
711+ } else {
712+ // for the AR path get the mtp_out from the mtp ctx
713+ src = llama_context_get_t_mtp_out (ctx_mtp);
714+ src_row = src ? (int32_t ) src->ne [1 ] - 1 : 0 ;
715+ llama_synchronize (ctx_mtp);
716+ }
717+ if (!src) {
718+ LOG_WRN (" %s: missing source tensor at k=%d; stopping chain\n " , __func__, k);
719+ return ;
720+ }
721+ ggml_backend_tensor_get (src, batch.embd ,
722+ (size_t ) src_row * row_bytes, row_bytes);
723+
724+ batch.token [0 ] = cond_tok;
725+ batch.pos [0 ] = pos;
726+
727+ const int32_t dec_rc = llama_decode (ctx_mtp, batch);
728+ if (dec_rc != 0 ) {
729+ LOG_DBG (" %s: llama_decode rc=%d at k=%d; stopping chain\n " , __func__, dec_rc, k);
730+ return ;
731+ }
732+
733+ const llama_token best = common_sampler_sample (smpl, ctx_mtp, 0 );
734+ common_sampler_accept (smpl, best, /* accept_grammar=*/ false );
735+ draft_tokens.push_back (best);
736+ cond_tok = best;
737+ ++pos;
738+ }
739+
740+ last_n_drafted = (uint16_t ) draft_tokens.size ();
741+ }
742+
743+ void accept (uint16_t n_accepted) override {
744+ const llama_pos pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
745+ const int32_t n_drafted_last = (int32_t ) last_n_drafted;
746+ const int32_t n_to_drop = std::max (0 , n_drafted_last - (int32_t ) n_accepted - 1 );
747+ if (pos_max < 0 ) {
748+ last_n_accepted = (int32_t ) n_accepted;
749+ return ;
750+ }
751+ if (n_to_drop > 0 ) {
752+ const llama_pos drop_from = pos_max - n_to_drop + 1 ;
753+ llama_memory_seq_rm (llama_get_memory (ctx_mtp), /* seq_id=*/ 0 ,
754+ /* p0=*/ drop_from, /* p1=*/ -1 );
755+ }
756+ last_n_drafted = 0 ;
757+ last_n_accepted = (int32_t ) n_accepted;
758+ }
759+
760+ int32_t n_max (const common_params_speculative & params) const override {
761+ return std::max (1 , params.draft .n_max );
762+ }
763+
764+ int32_t n_min (const common_params_speculative & params) const override {
765+ return std::max (1 , params.draft .n_min );
766+ }
767+ };
768+
602769// state of self-speculation (simple implementation, not ngram-map)
603770struct common_speculative_state_ngram_simple : public common_speculative_state {
604771 common_ngram_simple_config config;
@@ -952,6 +1119,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
9521119 case COMMON_SPECULATIVE_TYPE_NONE: return " none" ;
9531120 case COMMON_SPECULATIVE_TYPE_DRAFT: return " draft" ;
9541121 case COMMON_SPECULATIVE_TYPE_EAGLE3: return " eagle3" ;
1122+ case COMMON_SPECULATIVE_TYPE_MTP: return " mtp" ;
9551123 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return " ngram_simple" ;
9561124 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return " ngram_map_k" ;
9571125 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return " ngram_map_k4v" ;
@@ -983,11 +1151,24 @@ common_speculative * common_speculative_init(
9831151 }
9841152 }
9851153
1154+ llama_context * ctx_mtp = nullptr ;
1155+ if (params.has_mtp ()) {
1156+ ctx_mtp = llama_init_from_model (params.mtp .model , params.mtp .cparams );
1157+ if (ctx_mtp == nullptr ) {
1158+ LOG_ERR (" %s" , " failed to create MTP context\n " );
1159+ if (ctx_dft) {
1160+ llama_free (ctx_dft);
1161+ }
1162+ return nullptr ;
1163+ }
1164+ }
1165+
9861166 // Compute the implementations to use based on the config and their order of preference
9871167 std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
9881168 {
9891169 bool has_draft = !params.draft .mparams .path .empty ();
9901170 bool has_draft_eagle3 = false ; // TODO PR-18039: if params.speculative.eagle3
1171+ bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_mtp != nullptr );
9911172
9921173 bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
9931174 bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -1034,6 +1215,9 @@ common_speculative * common_speculative_init(
10341215 if (has_draft_eagle3) {
10351216 configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_EAGLE3, params));
10361217 }
1218+ if (has_mtp) {
1219+ configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_MTP, params));
1220+ }
10371221 }
10381222
10391223 std::vector<std::unique_ptr<common_speculative_state>> impls = {};
@@ -1058,6 +1242,11 @@ common_speculative * common_speculative_init(
10581242 impls.push_back (std::make_unique<common_speculative_state_eagle3>(config.type ));
10591243 break ;
10601244 }
1245+ case COMMON_SPECULATIVE_TYPE_MTP: {
1246+ impls.push_back (std::make_unique<common_speculative_state_mtp>(
1247+ config.type , ctx_tgt, ctx_mtp));
1248+ break ;
1249+ }
10611250 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
10621251 common_ngram_map ngram_map = get_common_ngram_map (config.type , config.params .ngram_simple );
10631252
0 commit comments