@@ -716,7 +716,7 @@ struct server_context_impl {
716716
717717 add_bos_token = llama_vocab_get_add_bos (vocab);
718718
719- if (params_base.speculative .type == COMMON_SPECULATIVE_TYPE_MTP || params_base. speculative . has_dft ()) {
719+ if (params_base.speculative .has_dft ()) {
720720 const auto & params_spec = params_base.speculative ;
721721
722722 auto params_dft = params_base;
@@ -736,23 +736,39 @@ struct server_context_impl {
736736 }
737737
738738 params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides ;
739- params_base.speculative .cparams_dft = common_context_params_to_llama (params_dft);
740739
741- if (params_base.speculative .requires_dft () && params_base.speculative .has_dft ()) {
742- SRV_INF (" loading draft model '%s'\n " , params_base.speculative .mparams_dft .path .c_str ());
740+ SRV_INF (" loading draft model '%s'\n " , params_base.speculative .mparams_dft .path .c_str ());
743741
744- params_dft. model = params_spec. mparams_dft ;
742+ auto mparams_dft = common_model_params_to_llama (params_dft) ;
745743
746- auto mparams_dft = common_model_params_to_llama (params_dft);
744+ model_dft.reset (llama_model_load_from_file (params_dft.model .path .c_str (), mparams_dft));
745+ if (model_dft == nullptr ) {
746+ SRV_ERR (" failed to load draft model, '%s'\n " , params_dft.model .path .c_str ());
747+ return false ;
748+ }
747749
748- model_dft.reset (llama_model_load_from_file (params_dft.model .path .c_str (), mparams_dft));
749- if (model_dft == nullptr ) {
750- SRV_ERR (" failed to load draft model, '%s'\n " , params_dft.model .path .c_str ());
751- return false ;
752- }
750+ params_base.speculative .model_dft = model_dft.get ();
751+ params_base.speculative .cparams_dft = common_context_params_to_llama (params_dft);
752+ } else if (params_base.speculative .type == COMMON_SPECULATIVE_TYPE_MTP ) {
753+ const auto & params_spec = params_base.speculative ;
754+
755+ auto params_dft = params_base;
756+
757+ params_dft.n_parallel = 1 ;
758+ params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq (ctx) : params_spec.n_ctx ;
759+ params_dft.n_batch = llama_n_ctx_seq (ctx);
760+ params_dft.devices = params_spec.devices ;
761+ params_dft.n_gpu_layers = params_spec.n_gpu_layers ;
762+ params_dft.cache_type_k = params_spec.cache_type_k ;
763+ params_dft.cache_type_v = params_spec.cache_type_v ;
753764
754- params_base.speculative .model_dft = model_dft.get ();
765+ if (params_spec.cpuparams .n_threads > 0 ) {
766+ params_dft.cpuparams .n_threads = params_spec.cpuparams .n_threads ;
767+ params_dft.cpuparams_batch .n_threads = params_spec.cpuparams_batch .n_threads ;
755768 }
769+
770+ params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides ;
771+ params_base.speculative .cparams_dft = common_context_params_to_llama (params_dft);
756772 }
757773
758774 std::string & mmproj_path = params_base.mmproj .path ;
@@ -2196,10 +2212,16 @@ struct server_context_impl {
21962212
21972213 if (slot.task ->params .speculative .n_min > (int ) draft.size ()) {
21982214 SLT_DBG (slot, " ignoring small draft: %d < %d\n " , (int ) draft.size (), slot.task ->params .speculative .n_min );
2199- // fallback to normal decoding
2200- slot.i_batch = slot.i_batch_dft [0 ];
22012215 slot.drafted .clear ();
2202- slot.i_batch_dft .clear ();
2216+ if (slot.task ->params .speculative .type != COMMON_SPECULATIVE_TYPE_MTP ) {
2217+ // Non-MTP speculation can safely fall back to plain decoding.
2218+ slot.i_batch = slot.i_batch_dft [0 ];
2219+ slot.i_batch_dft .clear ();
2220+ } else {
2221+ // MTP still needs a 0-accept speculative round so accept() can stage
2222+ // the frontier hidden state for the next shifted first pass.
2223+ slot.i_batch = -1 ;
2224+ }
22032225 } else {
22042226 // keep track of total number of drafted tokens tested
22052227 slot.n_draft_total += draft.size ();
0 commit comments