@@ -22,6 +22,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
2222 COMMON_SPECULATIVE_TYPE_NONE,
2323 COMMON_SPECULATIVE_TYPE_DRAFT,
2424 COMMON_SPECULATIVE_TYPE_EAGLE3,
25+ COMMON_SPECULATIVE_TYPE_MTP,
2526 COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
2627 COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
2728 COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
@@ -33,6 +34,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
3334 {" none" , COMMON_SPECULATIVE_TYPE_NONE},
3435 {" draft" , COMMON_SPECULATIVE_TYPE_DRAFT},
3536 {" eagle3" , COMMON_SPECULATIVE_TYPE_EAGLE3},
37+ {" mtp" , COMMON_SPECULATIVE_TYPE_MTP},
3638 {" ngram_simple" , COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
3739 {" ngram_map_k" , COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
3840 {" ngram_map_k4v" , COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -642,6 +644,171 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
642644 }
643645};
644646
647+ struct common_speculative_state_mtp : public common_speculative_state {
648+ llama_context * ctx_tgt = nullptr ;
649+ llama_context * ctx_mtp = nullptr ;
650+
651+ llama_batch batch; // single token draft step
652+ common_sampler * smpl = nullptr ;
653+ int32_t n_embd = 0 ;
654+
655+ uint16_t last_n_drafted = 0 ;
656+ int32_t last_n_accepted = -1 ;
657+
658+ common_speculative_state_mtp (enum common_speculative_type type,
659+ llama_context * ctx_tgt,
660+ llama_context * ctx_mtp)
661+ : common_speculative_state(type), ctx_tgt(ctx_tgt), ctx_mtp(ctx_mtp) {
662+ GGML_ASSERT (ctx_tgt && ctx_mtp);
663+ const llama_model * model_mtp = llama_get_model (ctx_mtp);
664+ n_embd = llama_model_n_embd (model_mtp);
665+
666+ {
667+ common_params_sampling sparams;
668+ sparams.no_perf = false ;
669+ sparams.top_k = 1 ;
670+ sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
671+ smpl = common_sampler_init (model_mtp, sparams);
672+ }
673+
674+ // TODO: multiple seq support
675+ batch = llama_batch_init (/* n_tokens=*/ 1 , /* embd=*/ n_embd, /* n_seq_max=*/ 1 );
676+ batch.token = (llama_token *) malloc (sizeof (llama_token));
677+ batch.n_tokens = 1 ;
678+ batch.n_seq_id [0 ] = 1 ;
679+ batch.seq_id [0 ][0 ] = 0 ;
680+ batch.logits [0 ] = 1 ;
681+
682+ llama_set_mtp (ctx_tgt, ctx_mtp);
683+ }
684+
685+ ~common_speculative_state_mtp () override {
686+ llama_set_mtp (ctx_tgt, nullptr );
687+ llama_batch_free (batch);
688+ common_sampler_free (smpl);
689+ if (ctx_mtp) {
690+ llama_free (ctx_mtp);
691+ }
692+ }
693+
694+ void begin (const llama_tokens & prompt) override {
695+ last_n_accepted = -1 ;
696+ last_n_drafted = 0 ;
697+
698+ const int32_t N = (int32_t ) prompt.size ();
699+ if (N <= 0 ) {
700+ return ;
701+ }
702+ const llama_pos pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
703+ if (pos_max < N - 1 ) {
704+ LOG_WRN (" %s: ctx_mtp pos_max=%d < N-1=%d — "
705+ " streaming hook may not be registered or not all prefill rows "
706+ " have logits=true. Drafts may degrade.\n " ,
707+ __func__, (int ) pos_max, N - 1 );
708+ }
709+ }
710+
711+ void draft (
712+ const common_params_speculative & params,
713+ const llama_tokens & prompt_tgt,
714+ llama_token id_last,
715+ llama_tokens & draft_tokens) override {
716+ GGML_UNUSED (prompt_tgt);
717+ draft_tokens.clear ();
718+
719+ // accept with no-accepts (i.e. 0 accepts) returns early, but we still need to remove from the MTP kv-cache
720+ // TODO: check if bug in other spec states
721+ if (last_n_drafted > 0 ) {
722+ const int32_t n_to_drop = (int32_t ) last_n_drafted - 1 ;
723+ if (n_to_drop > 0 ) {
724+ const llama_pos pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
725+ if (pos_max >= 0 ) {
726+ const llama_pos drop_from = pos_max - n_to_drop + 1 ;
727+ llama_memory_seq_rm (llama_get_memory (ctx_mtp), 0 , drop_from, -1 );
728+ }
729+ }
730+ last_n_drafted = 0 ;
731+ last_n_accepted = 0 ;
732+ }
733+
734+ const int32_t n_max = std::max (1 , params.draft .n_max );
735+ const size_t row_bytes = (size_t ) n_embd * sizeof (float );
736+
737+ llama_token cond_tok = id_last;
738+ llama_pos pos = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 ) + 1 ;
739+
740+ // auto-regressive loop for MTP
741+ for (int32_t k = 0 ; k < n_max; ++k) {
742+ ggml_tensor * src;
743+ int32_t src_row;
744+ if (k == 0 ) {
745+ src = llama_context_get_t_h_pre_norm (ctx_tgt);
746+ if (last_n_accepted < 0 ) {
747+ // First draft after begin(): trunk's most recent decode is
748+ // the last prefill ubatch; its last row is h_{N-1}.
749+ src_row = (src && src->ne [1 ] > 0 ) ? (int32_t ) src->ne [1 ] - 1 : 0 ;
750+ } else {
751+ src_row = last_n_accepted;
752+ }
753+ llama_synchronize (ctx_tgt);
754+ } else {
755+ // for the AR path get the mtp_out from the mtp ctx
756+ src = llama_context_get_t_mtp_out (ctx_mtp);
757+ src_row = src ? (int32_t ) src->ne [1 ] - 1 : 0 ;
758+ llama_synchronize (ctx_mtp);
759+ }
760+ if (!src) {
761+ LOG_WRN (" %s: missing source tensor at k=%d; stopping chain\n " , __func__, k);
762+ return ;
763+ }
764+ ggml_backend_tensor_get (src, batch.embd ,
765+ (size_t ) src_row * row_bytes, row_bytes);
766+
767+ batch.token [0 ] = cond_tok;
768+ batch.pos [0 ] = pos;
769+
770+ const int32_t dec_rc = llama_decode (ctx_mtp, batch);
771+ if (dec_rc != 0 ) {
772+ LOG_DBG (" %s: llama_decode rc=%d at k=%d; stopping chain\n " , __func__, dec_rc, k);
773+ return ;
774+ }
775+
776+ const llama_token best = common_sampler_sample (smpl, ctx_mtp, 0 );
777+ common_sampler_accept (smpl, best, /* accept_grammar=*/ false );
778+ draft_tokens.push_back (best);
779+ cond_tok = best;
780+ ++pos;
781+ }
782+
783+ last_n_drafted = (uint16_t ) draft_tokens.size ();
784+ }
785+
786+ void accept (uint16_t n_accepted) override {
787+ const llama_pos pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
788+ const int32_t n_drafted_last = (int32_t ) last_n_drafted;
789+ const int32_t n_to_drop = std::max (0 , n_drafted_last - (int32_t ) n_accepted - 1 );
790+ if (pos_max < 0 ) {
791+ last_n_accepted = (int32_t ) n_accepted;
792+ return ;
793+ }
794+ if (n_to_drop > 0 ) {
795+ const llama_pos drop_from = pos_max - n_to_drop + 1 ;
796+ llama_memory_seq_rm (llama_get_memory (ctx_mtp), /* seq_id=*/ 0 ,
797+ /* p0=*/ drop_from, /* p1=*/ -1 );
798+ }
799+ last_n_drafted = 0 ;
800+ last_n_accepted = (int32_t ) n_accepted;
801+ }
802+
803+ int32_t n_max (const common_params_speculative & params) const override {
804+ return std::max (1 , params.draft .n_max );
805+ }
806+
807+ int32_t n_min (const common_params_speculative & params) const override {
808+ return std::max (1 , params.draft .n_min );
809+ }
810+ };
811+
645812// state of self-speculation (simple implementation, not ngram-map)
646813struct common_speculative_state_ngram_simple : public common_speculative_state {
647814 common_ngram_simple_config config;
@@ -995,6 +1162,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
9951162 case COMMON_SPECULATIVE_TYPE_NONE: return " none" ;
9961163 case COMMON_SPECULATIVE_TYPE_DRAFT: return " draft" ;
9971164 case COMMON_SPECULATIVE_TYPE_EAGLE3: return " eagle3" ;
1165+ case COMMON_SPECULATIVE_TYPE_MTP: return " mtp" ;
9981166 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return " ngram_simple" ;
9991167 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return " ngram_map_k" ;
10001168 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return " ngram_map_k4v" ;
@@ -1026,11 +1194,24 @@ common_speculative * common_speculative_init(
10261194 }
10271195 }
10281196
1197+ llama_context * ctx_mtp = nullptr ;
1198+ if (params.type == COMMON_SPECULATIVE_TYPE_MTP) {
1199+ ctx_mtp = llama_init_from_model (params.mtp .model , params.mtp .cparams );
1200+ if (ctx_mtp == nullptr ) {
1201+ LOG_ERR (" %s" , " failed to create MTP context\n " );
1202+ if (ctx_dft) {
1203+ llama_free (ctx_dft);
1204+ }
1205+ return nullptr ;
1206+ }
1207+ }
1208+
10291209 // Compute the implementations to use based on the config and their order of preference
10301210 std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
10311211 {
10321212 bool has_draft = !params.draft .mparams .path .empty ();
10331213 bool has_draft_eagle3 = false ; // TODO PR-18039: if params.speculative.eagle3
1214+ bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_mtp != nullptr );
10341215
10351216 bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
10361217 bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -1077,6 +1258,9 @@ common_speculative * common_speculative_init(
10771258 if (has_draft_eagle3) {
10781259 configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_EAGLE3, params));
10791260 }
1261+ if (has_mtp) {
1262+ configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_MTP, params));
1263+ }
10801264 }
10811265
10821266 std::vector<std::unique_ptr<common_speculative_state>> impls = {};
@@ -1101,6 +1285,11 @@ common_speculative * common_speculative_init(
11011285 impls.push_back (std::make_unique<common_speculative_state_eagle3>(config.type ));
11021286 break ;
11031287 }
1288+ case COMMON_SPECULATIVE_TYPE_MTP: {
1289+ impls.push_back (std::make_unique<common_speculative_state_mtp>(
1290+ config.type , ctx_tgt, ctx_mtp));
1291+ break ;
1292+ }
11041293 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
11051294 common_ngram_map ngram_map = get_common_ngram_map (config.type , config.params .ngram_simple );
11061295
0 commit comments