@@ -2204,45 +2204,45 @@ struct server_context_impl {
22042204
22052205 if (spec) {
22062206 common_speculative_get_draft_params (spec.get (), slot.id ).drafting = false ;
2207- }
22082207
2209- const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ;
2210- const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ;
2208+ const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ;
2209+ const bool use_ckpt_dft = ctx_dft_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ;
22112210
2212- const int n_draft_max = slot.get_n_draft_max ();
2211+ const int n_draft_max = slot.get_n_draft_max ();
22132212
2214- if (n_draft_max > 0 ) {
2215- GGML_ASSERT (slot.can_speculate ());
2213+ if (n_draft_max > 0 ) {
2214+ GGML_ASSERT (slot.can_speculate ());
22162215
2217- if (!slot.spec_draft .empty ()) {
2218- // we have a previous (partial) draft to reuse
2219- if (use_ckpt_tgt) {
2220- GGML_ASSERT (!slot.spec_ckpt .empty ());
2221- }
2222- } else {
2223- GGML_ASSERT (slot.spec_i_batch .empty ());
2216+ if (!slot.spec_draft .empty ()) {
2217+ // we have a previous (partial) draft to reuse
2218+ if (use_ckpt_tgt) {
2219+ GGML_ASSERT (!slot.spec_ckpt .empty ());
2220+ }
2221+ } else {
2222+ GGML_ASSERT (slot.spec_i_batch .empty ());
22242223
2225- slot.spec_ckpt .update_pos (
2226- slot.prompt .n_tokens (),
2227- llama_memory_seq_pos_min (llama_get_memory (ctx_tgt), slot.id ),
2228- llama_memory_seq_pos_max (llama_get_memory (ctx_tgt), slot.id ));
2224+ slot.spec_ckpt .update_pos (
2225+ slot.prompt .n_tokens (),
2226+ llama_memory_seq_pos_min (llama_get_memory (ctx_tgt), slot.id ),
2227+ llama_memory_seq_pos_max (llama_get_memory (ctx_tgt), slot.id ));
22292228
2230- if (use_ckpt_dft) {
2231- slot.spec_ckpt .update_dft (ctx_dft.get (), slot.id , LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE );
2232- }
2229+ if (use_ckpt_dft) {
2230+ slot.spec_ckpt .update_dft (ctx_dft.get (), slot.id , LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE );
2231+ }
22332232
2234- slot.spec_prompt = slot.prompt .tokens .get_text_tokens ();
2233+ slot.spec_prompt = slot.prompt .tokens .get_text_tokens ();
22352234
2236- common_speculative_get_draft_params (spec.get (), slot.id ) = {
2237- /* .drafting = */ true ,
2238- /* .n_max = */ n_draft_max,
2239- /* .n_past = */ slot.prompt .n_tokens (),
2240- /* .id_last = */ slot.sampled ,
2241- /* .prompt = */ &slot.spec_prompt ,
2242- /* .result = */ &slot.spec_draft ,
2243- };
2235+ common_speculative_get_draft_params (spec.get (), slot.id ) = {
2236+ /* .drafting = */ true ,
2237+ /* .n_max = */ n_draft_max,
2238+ /* .n_past = */ slot.prompt .n_tokens (),
2239+ /* .id_last = */ slot.sampled ,
2240+ /* .prompt = */ &slot.spec_prompt ,
2241+ /* .result = */ &slot.spec_draft ,
2242+ };
22442243
2245- drafting.push_back (&slot);
2244+ drafting.push_back (&slot);
2245+ }
22462246 }
22472247 }
22482248 }
@@ -2256,29 +2256,33 @@ struct server_context_impl {
22562256 for (auto * slot_ptr : drafting) {
22572257 auto & slot = *slot_ptr;
22582258
2259- slot.n_draft_total += slot.spec_draft .size ();
2259+ auto & draft = slot.spec_draft ;
2260+ auto & ckpt = slot.spec_ckpt ;
2261+
2262+ slot.n_draft_total += draft.size ();
22602263
22612264 // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL]
22622265 if (ctx_dft) {
2263- slot. spec_ckpt .load_dft (ctx_dft.get (), slot.id , LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE );
2266+ ckpt .load_dft (ctx_dft.get (), slot.id , LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE );
22642267
2265- llama_memory_seq_rm (llama_get_memory (ctx_dft.get ()), slot.id , slot. spec_ckpt .pos_max + 1 , -1 );
2268+ llama_memory_seq_rm (llama_get_memory (ctx_dft.get ()), slot.id , ckpt .pos_max + 1 , -1 );
22662269 }
22672270
2268- const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ;
2271+ if (!draft.empty ()) {
2272+ const bool use_ckpt_tgt = ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ;
22692273
2270- if (!slot.spec_draft .empty ()) {
22712274 if (use_ckpt_tgt) {
22722275 // const int64_t t_start = ggml_time_us();
22732276
2274- slot. spec_ckpt .update_tgt (ctx_tgt, slot.id , LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE );
2277+ ckpt .update_tgt (ctx_tgt, slot.id , LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE );
22752278
22762279 // const int64_t t_total = ggml_time_us() - t_start;
22772280 // printf("checkpoint total: %f ms\n", t_total / 1000.0);
22782281
22792282 SLT_DBG (slot, " created speculative checkpoint (pos_min = %d, pos_max = %d, n_tokens = %d, size = %.3f MiB, draft = %.3f MiB)\n " ,
2280- slot.spec_ckpt .pos_min , slot.spec_ckpt .pos_max , slot.prompt .n_tokens (),
2281- (float ) slot.spec_ckpt .size () / 1024 / 1024 , (float ) slot.spec_ckpt .data_dft .size () / 1024 / 1024 );
2283+ ckpt.pos_min , ckpt.pos_max , slot.prompt .n_tokens (),
2284+ (float ) ckpt.size () / 1024 / 1024 ,
2285+ (float ) ckpt.data_dft .size () / 1024 / 1024 );
22822286 }
22832287 }
22842288 }
0 commit comments