File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1496,9 +1496,10 @@ struct llama_context_params common_context_params_to_llama(const common_params &
14961496 cparams.n_ctx = params.n_ctx ;
14971497 cparams.n_seq_max = params.n_parallel ;
14981498 {
1499- const bool has_spec = (params.speculative .type != COMMON_SPECULATIVE_TYPE_NONE)
1500- || params.speculative .has_dft ();
1501- cparams.n_rs_seq = has_spec ? (uint32_t ) params.speculative .draft .n_max : 0u ;
1499+ // enable partial rollback only for MTP: each recurrent slot requires memory,
1500+ // and MTP uses at most 3-4 slots, fewer than other techniques
1501+ const bool has_mtp_spec = params.speculative .type == COMMON_SPECULATIVE_TYPE_MTP;
1502+ cparams.n_rs_seq = has_mtp_spec ? (uint32_t ) params.speculative .draft .n_max : 0u ;
15021503 }
15031504 cparams.n_batch = params.n_batch ;
15041505 cparams.n_ubatch = params.n_ubatch ;
Original file line number Diff line number Diff line change @@ -370,10 +370,6 @@ struct common_params_speculative {
370370 bool has_dft () const {
371371 return !draft.mparams .path .empty () || !draft.mparams .hf_repo .empty ();
372372 }
373-
374- bool has_mtp () const {
375- return type == COMMON_SPECULATIVE_TYPE_MTP && mtp.model != nullptr ;
376- }
377373};
378374
379375struct common_params_vocoder {
Original file line number Diff line number Diff line change @@ -1152,7 +1152,7 @@ common_speculative * common_speculative_init(
11521152 }
11531153
11541154 llama_context * ctx_mtp = nullptr ;
1155- if (params.has_mtp () ) {
1155+ if (params.type == COMMON_SPECULATIVE_TYPE_MTP ) {
11561156 ctx_mtp = llama_init_from_model (params.mtp .model , params.mtp .cparams );
11571157 if (ctx_mtp == nullptr ) {
11581158 LOG_ERR (" %s" , " failed to create MTP context\n " );
Original file line number Diff line number Diff line change @@ -963,7 +963,7 @@ struct server_context_impl {
963963
964964 // try speculative decoding
965965 if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
966- slot.is_mtp_enabled = params_base.speculative .has_mtp () ;
966+ slot.is_mtp_enabled = params_base.speculative .type == COMMON_SPECULATIVE_TYPE_MTP ;
967967 slot.spec .reset (common_speculative_init (params_base.speculative , slot.ctx ));
968968
969969 if (slot.spec ) {
You can’t perform that action at this time.
0 commit comments