Skip to content

Commit 5d5f1b4

Browse files
committed
fix: use rs for only MTP
1 parent 86d9f15 commit 5d5f1b4

4 files changed

Lines changed: 6 additions & 9 deletions

File tree

common/common.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1496,9 +1496,10 @@ struct llama_context_params common_context_params_to_llama(const common_params &
14961496
cparams.n_ctx = params.n_ctx;
14971497
cparams.n_seq_max = params.n_parallel;
14981498
{
1499-
const bool has_spec = (params.speculative.type != COMMON_SPECULATIVE_TYPE_NONE)
1500-
|| params.speculative.has_dft();
1501-
cparams.n_rs_seq = has_spec ? (uint32_t) params.speculative.draft.n_max : 0u;
1499+
// enable partial rollback only for MTP, each recurrent slot requires memory
1500+
// and MTP uses max 3-4 slots vs other techniques
1501+
const bool has_mtp_spec = params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP;
1502+
cparams.n_rs_seq = has_mtp_spec ? (uint32_t) params.speculative.draft.n_max : 0u;
15021503
}
15031504
cparams.n_batch = params.n_batch;
15041505
cparams.n_ubatch = params.n_ubatch;

common/common.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -370,10 +370,6 @@ struct common_params_speculative {
370370
bool has_dft() const {
371371
return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
372372
}
373-
374-
bool has_mtp() const {
375-
return type == COMMON_SPECULATIVE_TYPE_MTP && mtp.model != nullptr;
376-
}
377373
};
378374

379375
struct common_params_vocoder {

common/speculative.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1152,7 +1152,7 @@ common_speculative * common_speculative_init(
11521152
}
11531153

11541154
llama_context * ctx_mtp = nullptr;
1155-
if (params.has_mtp()) {
1155+
if (params.type == COMMON_SPECULATIVE_TYPE_MTP) {
11561156
ctx_mtp = llama_init_from_model(params.mtp.model, params.mtp.cparams);
11571157
if (ctx_mtp == nullptr) {
11581158
LOG_ERR("%s", "failed to create MTP context\n");

tools/server/server-context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,7 @@ struct server_context_impl {
963963

964964
// try speculative decoding
965965
if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
966-
slot.is_mtp_enabled = params_base.speculative.has_mtp();
966+
slot.is_mtp_enabled = params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP;
967967
slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));
968968

969969
if (slot.spec) {

0 commit comments

Comments
 (0)