@@ -704,7 +704,7 @@ struct server_context_impl {
704704
705705 add_bos_token = llama_vocab_get_add_bos (vocab);
706706
707- if (params_base.speculative .type == COMMON_SPECULATIVE_TYPE_MTP || params_base. speculative . has_dft ()) {
707+ if (params_base.speculative .has_dft ()) {
708708 const auto & params_spec = params_base.speculative ;
709709
710710 auto params_dft = params_base;
@@ -724,23 +724,39 @@ struct server_context_impl {
724724 }
725725
726726 params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides ;
727- params_base.speculative .cparams_dft = common_context_params_to_llama (params_dft);
728727
729- if (params_base.speculative .requires_dft () && params_base.speculative .has_dft ()) {
730- SRV_INF (" loading draft model '%s'\n " , params_base.speculative .mparams_dft .path .c_str ());
728+ SRV_INF (" loading draft model '%s'\n " , params_base.speculative .mparams_dft .path .c_str ());
731729
732- params_dft. model = params_spec. mparams_dft ;
730+ auto mparams_dft = common_model_params_to_llama (params_dft) ;
733731
734- auto mparams_dft = common_model_params_to_llama (params_dft);
732+ model_dft.reset (llama_model_load_from_file (params_dft.model .path .c_str (), mparams_dft));
733+ if (model_dft == nullptr ) {
734+ SRV_ERR (" failed to load draft model, '%s'\n " , params_dft.model .path .c_str ());
735+ return false ;
736+ }
735737
736- model_dft.reset (llama_model_load_from_file (params_dft.model .path .c_str (), mparams_dft));
737- if (model_dft == nullptr ) {
738- SRV_ERR (" failed to load draft model, '%s'\n " , params_dft.model .path .c_str ());
739- return false ;
740- }
738+ params_base.speculative .model_dft = model_dft.get ();
739+ params_base.speculative .cparams_dft = common_context_params_to_llama (params_dft);
740+ } else if (params_base.speculative .type == COMMON_SPECULATIVE_TYPE_MTP ) {
741+ const auto & params_spec = params_base.speculative ;
742+
743+ auto params_dft = params_base;
744+
745+ params_dft.n_parallel = 1 ;
746+ params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq (ctx) : params_spec.n_ctx ;
747+ params_dft.n_batch = llama_n_ctx_seq (ctx);
748+ params_dft.devices = params_spec.devices ;
749+ params_dft.n_gpu_layers = params_spec.n_gpu_layers ;
750+ params_dft.cache_type_k = params_spec.cache_type_k ;
751+ params_dft.cache_type_v = params_spec.cache_type_v ;
741752
742- params_base.speculative .model_dft = model_dft.get ();
753+ if (params_spec.cpuparams .n_threads > 0 ) {
754+ params_dft.cpuparams .n_threads = params_spec.cpuparams .n_threads ;
755+ params_dft.cpuparams_batch .n_threads = params_spec.cpuparams_batch .n_threads ;
743756 }
757+
758+ params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides ;
759+ params_base.speculative .cparams_dft = common_context_params_to_llama (params_dft);
744760 }
745761
746762 std::string & mmproj_path = params_base.mmproj .path ;
@@ -2162,10 +2178,16 @@ struct server_context_impl {
21622178
21632179 if (slot.task ->params .speculative .n_min > (int ) draft.size ()) {
21642180 SLT_DBG (slot, " ignoring small draft: %d < %d\n " , (int ) draft.size (), slot.task ->params .speculative .n_min );
2165- // fallback to normal decoding
2166- slot.i_batch = slot.i_batch_dft [0 ];
21672181 slot.drafted .clear ();
2168- slot.i_batch_dft .clear ();
2182+ if (slot.task ->params .speculative .type != COMMON_SPECULATIVE_TYPE_MTP ) {
2183+ // Non-MTP speculation can safely fall back to plain decoding.
2184+ slot.i_batch = slot.i_batch_dft [0 ];
2185+ slot.i_batch_dft .clear ();
2186+ } else {
2187+ // MTP still needs a 0-accept speculative round so accept() can stage
2188+ // the frontier hidden state for the next shifted first pass.
2189+ slot.i_batch = -1 ;
2190+ }
21692191 } else {
21702192 // keep track of total number of drafted tokens tested
21712193 slot.n_draft_total += draft.size ();
0 commit comments