@@ -22,6 +22,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
2222 COMMON_SPECULATIVE_TYPE_NONE,
2323 COMMON_SPECULATIVE_TYPE_DRAFT,
2424 COMMON_SPECULATIVE_TYPE_EAGLE3,
25+ COMMON_SPECULATIVE_TYPE_MTP,
2526 COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
2627 COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
2728 COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
@@ -33,6 +34,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
3334 {" none" , COMMON_SPECULATIVE_TYPE_NONE},
3435 {" draft" , COMMON_SPECULATIVE_TYPE_DRAFT},
3536 {" eagle3" , COMMON_SPECULATIVE_TYPE_EAGLE3},
37+ {" mtp" , COMMON_SPECULATIVE_TYPE_MTP},
3638 {" ngram_simple" , COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
3739 {" ngram_map_k" , COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
3840 {" ngram_map_k4v" , COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -608,6 +610,171 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
608610 }
609611};
610612
613+ struct common_speculative_state_mtp : public common_speculative_state {
614+ llama_context * ctx_tgt = nullptr ;
615+ llama_context * ctx_mtp = nullptr ;
616+
617+ llama_batch batch; // single token draft step
618+ common_sampler * smpl = nullptr ;
619+ int32_t n_embd = 0 ;
620+
621+ uint16_t last_n_drafted = 0 ;
622+ int32_t last_n_accepted = -1 ;
623+
624+ common_speculative_state_mtp (enum common_speculative_type type,
625+ llama_context * ctx_tgt,
626+ llama_context * ctx_mtp)
627+ : common_speculative_state(type), ctx_tgt(ctx_tgt), ctx_mtp(ctx_mtp) {
628+ GGML_ASSERT (ctx_tgt && ctx_mtp);
629+ const llama_model * model_mtp = llama_get_model (ctx_mtp);
630+ n_embd = llama_model_n_embd (model_mtp);
631+
632+ {
633+ common_params_sampling sparams;
634+ sparams.no_perf = false ;
635+ sparams.top_k = 1 ;
636+ sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
637+ smpl = common_sampler_init (model_mtp, sparams);
638+ }
639+
640+ // TODO: multiple seq support
641+ batch = llama_batch_init (/* n_tokens=*/ 1 , /* embd=*/ n_embd, /* n_seq_max=*/ 1 );
642+ batch.token = (llama_token *) malloc (sizeof (llama_token));
643+ batch.n_tokens = 1 ;
644+ batch.n_seq_id [0 ] = 1 ;
645+ batch.seq_id [0 ][0 ] = 0 ;
646+ batch.logits [0 ] = 1 ;
647+
648+ llama_set_mtp (ctx_tgt, ctx_mtp);
649+ }
650+
651+ ~common_speculative_state_mtp () override {
652+ llama_set_mtp (ctx_tgt, nullptr );
653+ llama_batch_free (batch);
654+ common_sampler_free (smpl);
655+ if (ctx_mtp) {
656+ llama_free (ctx_mtp);
657+ }
658+ }
659+
660+ void begin (const llama_tokens & prompt) override {
661+ last_n_accepted = -1 ;
662+ last_n_drafted = 0 ;
663+
664+ const int32_t N = (int32_t ) prompt.size ();
665+ if (N <= 0 ) {
666+ return ;
667+ }
668+ const llama_pos pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
669+ if (pos_max < N - 1 ) {
670+ LOG_WRN (" %s: ctx_mtp pos_max=%d < N-1=%d — "
671+ " streaming hook may not be registered or not all prefill rows "
672+ " have logits=true. Drafts may degrade.\n " ,
673+ __func__, (int ) pos_max, N - 1 );
674+ }
675+ }
676+
677+ void draft (
678+ const common_params_speculative & params,
679+ const llama_tokens & prompt_tgt,
680+ llama_token id_last,
681+ llama_tokens & draft_tokens) override {
682+ GGML_UNUSED (prompt_tgt);
683+ draft_tokens.clear ();
684+
685+ // accept with no-accepts (i.e. 0 accepts) returns early, but we still need to remove from the MTP kv-cache
686+ // TODO: check if bug in other spec states
687+ if (last_n_drafted > 0 ) {
688+ const int32_t n_to_drop = (int32_t ) last_n_drafted - 1 ;
689+ if (n_to_drop > 0 ) {
690+ const llama_pos pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
691+ if (pos_max >= 0 ) {
692+ const llama_pos drop_from = pos_max - n_to_drop + 1 ;
693+ llama_memory_seq_rm (llama_get_memory (ctx_mtp), 0 , drop_from, -1 );
694+ }
695+ }
696+ last_n_drafted = 0 ;
697+ last_n_accepted = 0 ;
698+ }
699+
700+ const int32_t n_max = std::max (1 , params.draft .n_max );
701+ const size_t row_bytes = (size_t ) n_embd * sizeof (float );
702+
703+ llama_token cond_tok = id_last;
704+ llama_pos pos = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 ) + 1 ;
705+
706+ // auto-regressive loop for MTP
707+ for (int32_t k = 0 ; k < n_max; ++k) {
708+ ggml_tensor * src;
709+ int32_t src_row;
710+ if (k == 0 ) {
711+ src = llama_context_get_t_h_pre_norm (ctx_tgt);
712+ if (last_n_accepted < 0 ) {
713+ // First draft after begin(): trunk's most recent decode is
714+ // the last prefill ubatch; its last row is h_{N-1}.
715+ src_row = (src && src->ne [1 ] > 0 ) ? (int32_t ) src->ne [1 ] - 1 : 0 ;
716+ } else {
717+ src_row = last_n_accepted;
718+ }
719+ llama_synchronize (ctx_tgt);
720+ } else {
721+ // for the AR path get the mtp_out from the mtp ctx
722+ src = llama_context_get_t_mtp_out (ctx_mtp);
723+ src_row = src ? (int32_t ) src->ne [1 ] - 1 : 0 ;
724+ llama_synchronize (ctx_mtp);
725+ }
726+ if (!src) {
727+ LOG_WRN (" %s: missing source tensor at k=%d; stopping chain\n " , __func__, k);
728+ return ;
729+ }
730+ ggml_backend_tensor_get (src, batch.embd ,
731+ (size_t ) src_row * row_bytes, row_bytes);
732+
733+ batch.token [0 ] = cond_tok;
734+ batch.pos [0 ] = pos;
735+
736+ const int32_t dec_rc = llama_decode (ctx_mtp, batch);
737+ if (dec_rc != 0 ) {
738+ LOG_DBG (" %s: llama_decode rc=%d at k=%d; stopping chain\n " , __func__, dec_rc, k);
739+ return ;
740+ }
741+
742+ const llama_token best = common_sampler_sample (smpl, ctx_mtp, 0 );
743+ common_sampler_accept (smpl, best, /* accept_grammar=*/ false );
744+ draft_tokens.push_back (best);
745+ cond_tok = best;
746+ ++pos;
747+ }
748+
749+ last_n_drafted = (uint16_t ) draft_tokens.size ();
750+ }
751+
752+ void accept (uint16_t n_accepted) override {
753+ const llama_pos pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
754+ const int32_t n_drafted_last = (int32_t ) last_n_drafted;
755+ const int32_t n_to_drop = std::max (0 , n_drafted_last - (int32_t ) n_accepted - 1 );
756+ if (pos_max < 0 ) {
757+ last_n_accepted = (int32_t ) n_accepted;
758+ return ;
759+ }
760+ if (n_to_drop > 0 ) {
761+ const llama_pos drop_from = pos_max - n_to_drop + 1 ;
762+ llama_memory_seq_rm (llama_get_memory (ctx_mtp), /* seq_id=*/ 0 ,
763+ /* p0=*/ drop_from, /* p1=*/ -1 );
764+ }
765+ last_n_drafted = 0 ;
766+ last_n_accepted = (int32_t ) n_accepted;
767+ }
768+
769+ int32_t n_max (const common_params_speculative & params) const override {
770+ return std::max (1 , params.draft .n_max );
771+ }
772+
773+ int32_t n_min (const common_params_speculative & params) const override {
774+ return std::max (1 , params.draft .n_min );
775+ }
776+ };
777+
611778// state of self-speculation (simple implementation, not ngram-map)
612779struct common_speculative_state_ngram_simple : public common_speculative_state {
613780 common_ngram_simple_config config;
@@ -963,6 +1130,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
9631130 case COMMON_SPECULATIVE_TYPE_NONE: return " none" ;
9641131 case COMMON_SPECULATIVE_TYPE_DRAFT: return " draft" ;
9651132 case COMMON_SPECULATIVE_TYPE_EAGLE3: return " eagle3" ;
1133+ case COMMON_SPECULATIVE_TYPE_MTP: return " mtp" ;
9661134 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return " ngram_simple" ;
9671135 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return " ngram_map_k" ;
9681136 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return " ngram_map_k4v" ;
@@ -994,11 +1162,24 @@ common_speculative * common_speculative_init(
9941162 }
9951163 }
9961164
1165+ llama_context * ctx_mtp = nullptr ;
1166+ if (params.has_mtp ()) {
1167+ ctx_mtp = llama_init_from_model (params.mtp .model , params.mtp .cparams );
1168+ if (ctx_mtp == nullptr ) {
1169+ LOG_ERR (" %s" , " failed to create MTP context\n " );
1170+ if (ctx_dft) {
1171+ llama_free (ctx_dft);
1172+ }
1173+ return nullptr ;
1174+ }
1175+ }
1176+
9971177 // Compute the implementations to use based on the config and their order of preference
9981178 std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
9991179 {
10001180 bool has_draft = !params.draft .mparams .path .empty ();
10011181 bool has_draft_eagle3 = false ; // TODO PR-18039: if params.speculative.eagle3
1182+ bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_mtp != nullptr );
10021183
10031184 bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
10041185 bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -1045,6 +1226,9 @@ common_speculative * common_speculative_init(
10451226 if (has_draft_eagle3) {
10461227 configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_EAGLE3, params));
10471228 }
1229+ if (has_mtp) {
1230+ configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_MTP, params));
1231+ }
10481232 }
10491233
10501234 std::vector<std::unique_ptr<common_speculative_state>> impls = {};
@@ -1069,6 +1253,11 @@ common_speculative * common_speculative_init(
10691253 impls.push_back (std::make_unique<common_speculative_state_eagle3>(config.type ));
10701254 break ;
10711255 }
1256+ case COMMON_SPECULATIVE_TYPE_MTP: {
1257+ impls.push_back (std::make_unique<common_speculative_state_mtp>(
1258+ config.type , ctx_tgt, ctx_mtp));
1259+ break ;
1260+ }
10721261 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
10731262 common_ngram_map ngram_map = get_common_ngram_map (config.type , config.params .ngram_simple );
10741263
0 commit comments