@@ -78,9 +78,10 @@ enum server_state {
 struct server_slot {
     int id;

-    // TODO: change to unique_ptrs for consistency:
     llama_context * ctx = nullptr;

+    common_context_seq_rm_type ctx_seq_rm_type = COMMON_CONTEXT_SEQ_RM_TYPE_NO;
+
     // multimodal
     mtmd_context * mctx = nullptr;

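Note: the new `ctx_seq_rm_type` member records what kind of sequence removal the slot's context supports. Only the probe and the two enum values that appear in this diff are confirmed; a minimal sketch of the assumed interface:

```cpp
// sketch under assumptions: only COMMON_CONTEXT_SEQ_RM_TYPE_NO/_FULL and
// common_context_can_seq_rm() are confirmed by this diff; the rest is a guess
enum common_context_seq_rm_type {
    COMMON_CONTEXT_SEQ_RM_TYPE_NO,   // sequence removal unsupported  -> no speculative decoding
    COMMON_CONTEXT_SEQ_RM_TYPE_FULL, // only whole-sequence removal   -> checkpoints required
    // a context that supports partial removal presumably reports a third value
};

// probes the context (rather than the model architecture) for its capabilities
common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx);
```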
@@ -90,7 +91,6 @@ struct server_slot {
     server_prompt_checkpoint spec_ckpt;
     common_speculative_ptr spec;

-
     // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
     // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
     std::unique_ptr<const server_task> task;
@@ -343,7 +343,7 @@ struct server_slot {

         if (!spec_draft.empty()) {
             // we have a previous (partial) draft to reuse
-            if (task->params.speculative.use_checkpoints) {
+            if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
                 GGML_ASSERT(!spec_ckpt.empty());
             }
         } else {
@@ -362,15 +362,13 @@ struct server_slot {
             spec_draft.clear();
         }

-        if (!spec_draft.empty() && params_spec.use_checkpoints) {
+        if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
             const auto n_tokens = prompt.tokens.size();

-            auto & ckpt = spec_ckpt;
-
-            ckpt = server_get_checkpoint(ctx, this->id, n_tokens);
+            spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens);

             SLT_DBG(*this, "created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %zu, size = %.3f MiB)\n",
-                    ckpt.pos_min, ckpt.pos_max, n_tokens, (float) ckpt.data.size() / 1024 / 1024);
+                    spec_ckpt.pos_min, spec_ckpt.pos_max, n_tokens, (float) spec_ckpt.data.size() / 1024 / 1024);
         }
     }

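For reference, a rough sketch of what `server_get_checkpoint()` is assumed to do, based on the fields used above (`pos_min`, `pos_max`, `data`) and the matching `llama_state_seq_set_data_ext()` restore later in this diff; this is not the actual implementation:

```cpp
// hypothetical sketch: snapshot the part of the sequence state that cannot
// be rolled back in place (SWA / recurrent state), via the PARTIAL_ONLY flag
static server_prompt_checkpoint server_get_checkpoint(llama_context * ctx, llama_seq_id seq_id, size_t n_tokens) {
    server_prompt_checkpoint ckpt;

    llama_memory_t mem = llama_get_memory(ctx);

    ckpt.pos_min = llama_memory_seq_pos_min(mem, seq_id);
    ckpt.pos_max = llama_memory_seq_pos_max(mem, seq_id);

    const size_t size = llama_state_seq_get_size_ext(ctx, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    ckpt.data.resize(size);
    llama_state_seq_get_data_ext(ctx, ckpt.data.data(), size, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);

    GGML_UNUSED(n_tokens); // logged by the caller; the real helper may also store it

    return ckpt;
}
```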
@@ -871,14 +869,13 @@ struct server_context_impl {

         slots.clear();

-        const auto spec_type = common_speculative_is_compat(ctx);
-        if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
+        const auto ctx_seq_rm_type = common_context_can_seq_rm(ctx);
+        if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
             SRV_WRN("%s", "speculative decoding not supported by this context\n");
         }

-        if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) {
+        if (ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
             SRV_WRN("%s", "speculative decoding will use checkpoints\n");
-            params_base.speculative.use_checkpoints = true;
         }

         // initialize slots
@@ -893,11 +890,13 @@ struct server_context_impl {
             slot.ctx = ctx;
             slot.n_ctx = n_ctx_slot;

+            slot.ctx_seq_rm_type = ctx_seq_rm_type;
+
             slot.mctx = mctx;
             slot.prompt.tokens.has_mtmd = mctx != nullptr;

             // try speculative decoding
-            if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) {
+            if (ctx_seq_rm_type != COMMON_CONTEXT_SEQ_RM_TYPE_NO) {
                 slot.spec.reset(common_speculative_init(params_base.speculative, slot.ctx));

                 if (slot.spec) {
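In short, the probe result now gates the feature per slot; the third case below is inferred from the fact that checkpoints are only forced for `_FULL` and speculation is only disabled for `_NO`:

```cpp
// COMMON_CONTEXT_SEQ_RM_TYPE_NO   -> slot.spec stays null, speculative decoding disabled
// COMMON_CONTEXT_SEQ_RM_TYPE_FULL -> speculative decoding enabled, checkpoints required
// any other value (inferred)      -> speculative decoding enabled, partial rollback in place
```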
@@ -2588,15 +2587,11 @@ struct server_context_impl {

         // make a checkpoint of the parts of the memory that cannot be rolled back.
         // checkpoints are created only if:
+        //  - the context does not support partial sequence removal
         //  - the model uses SWA and we are not using `swa_full`
-        //  - the model architecture is marked as recurrent or hybrid
-        //
-        // TODO: try to make this conditional on the context or the memory module, instead of the model type
         do_checkpoint = do_checkpoint && (
-            llama_model_is_recurrent(model) ||
-            llama_model_is_hybrid(model) ||
-            (llama_model_n_swa(model) > 0 && !params_base.swa_full)
-        );
+                (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) ||
+                (llama_model_n_swa(model) > 0 && !params_base.swa_full));

         bool has_mtmd = false;
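Spelled out, the new condition swaps the model-type heuristic for the context capability check while keeping the SWA case; a hypothetical equivalent predicate:

```cpp
// hypothetical helper, equivalent to the condition above (not part of the patch)
static bool server_slot_needs_checkpoint(const server_slot & slot, const llama_model * model, bool swa_full) {
    // the context can only remove whole sequences -> no partial rollback
    if (slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) {
        return true;
    }

    // SWA without --swa-full discards positions outside the attention window
    return llama_model_n_swa(model) > 0 && !swa_full;
}
```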
26022597
@@ -2965,8 +2960,6 @@ struct server_context_impl {
29652960
29662961 // verify and try to accept the draft
29672962 {
2968- const auto & params_spec = slot.task ->params .speculative ;
2969-
29702963 common_sampler_ptr smpl_save (common_sampler_clone (slot.smpl .get ()));
29712964
29722965 GGML_ASSERT (slot.spec_i_batch .size () == n_draft + 1 );
@@ -2979,13 +2972,14 @@ struct server_context_impl {
29792972
29802973 // check for partial draft acceptance
29812974 if (accepted.size () < slot.spec_draft .size () + 1 ) {
2982- if (params_spec. use_checkpoints ) {
2975+ if (slot. ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ) {
29832976 // partial acceptance is not supported by the context -> truncate the draft and restore the state
29842977 slot.spec_draft = std::move (accepted);
29852978
2986- auto & ckpt = slot.spec_ckpt ;
2979+ const auto & ckpt = slot.spec_ckpt ;
29872980
2988- SLT_DBG (slot, " restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n " , ckpt.pos_min , ckpt.pos_max , ckpt.size ());
2981+ SLT_DBG (slot, " restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n " ,
2982+ ckpt.pos_min , ckpt.pos_max , ckpt.size ());
29892983
29902984 const size_t n = llama_state_seq_set_data_ext (slot.ctx , ckpt.data .data (), ckpt.size (), slot.id , LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
29912985 if (n != ckpt.size ()) {