Skip to content

Commit 7c158fb

Browse files
authored
server : disable on-device spec checkpoints (ggml-org#24108)
1 parent 260862b commit 7c158fb

2 files changed

Lines changed: 11 additions & 11 deletions

File tree

examples/speculative-simple/speculative-simple.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ int main(int argc, char ** argv) {
175175
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), seq_id));
176176

177177
if (use_ckpt_dft) {
178-
ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
178+
ckpt.update_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
179179
}
180180

181181
// generate a new draft
@@ -196,12 +196,12 @@ int main(int argc, char ** argv) {
196196
// this allows us to restore the state if partial draft acceptance occurs
197197
if (!draft.empty()) {
198198
if (use_ckpt_tgt) {
199-
ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
199+
ckpt.update_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
200200
}
201201
}
202202

203203
{
204-
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
204+
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
205205

206206
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
207207
}
@@ -261,13 +261,13 @@ int main(int argc, char ** argv) {
261261
draft = std::move(ids);
262262

263263
{
264-
ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
264+
ckpt.load_tgt(ctx_tgt, seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
265265

266266
llama_memory_seq_rm(llama_get_memory(ctx_tgt), seq_id, ckpt.pos_max + 1, -1);
267267
}
268268

269269
{
270-
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
270+
ckpt.load_dft(ctx_dft.get(), seq_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
271271

272272
llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), seq_id, ckpt.pos_max + 1, -1);
273273
}

tools/server/server-context.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2512,7 +2512,7 @@ struct server_context_impl {
25122512
llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id));
25132513

25142514
if (use_ckpt_dft) {
2515-
slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
2515+
slot.spec_ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
25162516
}
25172517

25182518
slot.spec_prompt = slot.prompt.tokens.get_text_tokens();
@@ -2551,7 +2551,7 @@ struct server_context_impl {
25512551

25522552
if (ctx_dft) {
25532553
if (use_ckpt_dft) {
2554-
ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
2554+
ckpt.load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
25552555
}
25562556

25572557
common_context_seq_rm(ctx_dft.get(), slot.id, ckpt.pos_max + 1, -1);
@@ -2568,7 +2568,7 @@ struct server_context_impl {
25682568
if (use_ckpt_tgt) {
25692569
//const int64_t t_start = ggml_time_us();
25702570

2571-
ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
2571+
ckpt.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
25722572

25732573
//const int64_t t_total = ggml_time_us() - t_start;
25742574
//printf("checkpoint total: %f ms\n", t_total / 1000.0);
@@ -2580,7 +2580,7 @@ struct server_context_impl {
25802580
}
25812581

25822582
if (use_ckpt_dft) {
2583-
ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
2583+
ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
25842584
}
25852585
}
25862586
}
@@ -3447,13 +3447,13 @@ struct server_context_impl {
34473447
SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size());
34483448

34493449
{
3450-
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
3450+
ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
34513451

34523452
common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1);
34533453
}
34543454

34553455
if (slot.ctx_dft) {
3456-
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY | LLAMA_STATE_SEQ_FLAGS_ON_DEVICE);
3456+
ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
34573457

34583458
common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1);
34593459
}

0 commit comments

Comments
 (0)