Skip to content

Commit 61f6090

Browse files
Merge pull request #491 from janhq/update-dev-from-master-2026-04-20-00-58
Sync master with upstream release b8851
2 parents 8f3cfc6 + e365e65 commit 61f6090

33 files changed

Lines changed: 1743 additions & 1645 deletions

.github/workflows/build-cross.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ jobs:
246246
apt-get install -y --no-install-recommends \
247247
build-essential \
248248
glslc \
249+
spirv-headers \
249250
gcc-14-loongarch64-linux-gnu \
250251
g++-14-loongarch64-linux-gnu \
251252
libvulkan-dev:loong64

common/chat-auto-parser-generator.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -443,14 +443,14 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
443443
if (!format.per_call_start.empty()) {
444444
auto wrapped_call = format.per_call_start + p.space() + tool_choice + p.space() + format.per_call_end;
445445
if (inputs.parallel_tool_calls) {
446-
tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call));
446+
tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call) + p.space());
447447
} else {
448-
tool_calls = p.trigger_rule("tool-call", wrapped_call);
448+
tool_calls = p.trigger_rule("tool-call", wrapped_call + p.space());
449449
}
450450
if (!format.section_start.empty()) {
451451
tool_calls = p.trigger_rule("tool-calls",
452452
p.literal(format.section_start) + p.space() + tool_calls + p.space() +
453-
(format.section_end.empty() ? p.end() : p.literal(format.section_end)));
453+
(format.section_end.empty() ? p.end() : p.literal(format.section_end) + p.space()));
454454
}
455455
} else {
456456
std::string separator = ", "; // Default

common/chat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2334,7 +2334,7 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars
23342334
? input
23352335
: params.generation_prompt + input;
23362336

2337-
LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
2337+
//LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), effective_input.c_str());
23382338

23392339
common_peg_parse_flags flags = COMMON_PEG_PARSE_FLAG_LENIENT;
23402340
if (params.debug) {

common/common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include <sstream>
1212
#include <string>
1313
#include <string_view>
14-
#include <variant>
1514
#include <vector>
1615
#include <map>
1716

@@ -303,7 +302,7 @@ struct common_params_speculative {
303302
// general-purpose speculative decoding parameters
304303

305304
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
306-
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
305+
int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding
307306
float p_split = 0.1f; // speculative decoding split probability
308307
float p_min = 0.75f; // minimum speculative decoding probability (greedy)
309308

@@ -312,6 +311,7 @@ struct common_params_speculative {
312311
uint16_t ngram_size_n = 12; // ngram size for lookup
313312
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
314313
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
314+
bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models
315315

316316
std::shared_ptr<common_ngram_mod> ngram_mod;
317317

common/ngram-map.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ void common_ngram_map_begin(
208208
count_keys, count_keys_del, count_values_del, count_map_entries_upd);
209209
}
210210

211-
map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0;
211+
map.idx_last_check = size_begin;
212212
map.size_last_begin = size_begin;
213213
}
214214

@@ -231,7 +231,7 @@ void common_ngram_map_draft(common_ngram_map & map,
231231
GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
232232
}
233233

234-
if (map.idx_last_check > cur_len) {
234+
if (map.idx_last_check > cur_len) {
235235
// Should not happen because of common_ngram_map_begin().
236236
GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
237237
}
@@ -386,7 +386,7 @@ void common_ngram_map_draft(common_ngram_map & map,
386386
LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
387387
curr_key.key_idx, key_offset, curr_key.key_num, draft.size());
388388

389-
map.last_draft_created = false;
389+
map.last_draft_created = true;
390390
map.last_draft_key_idx = key_offset;
391391
map.last_draft_value_idx = 0; // value 0 is used for simple mode
392392
return;
@@ -524,7 +524,7 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
524524
struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation.
525525

526526
// update the value statistics
527-
LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
527+
LOG_DBG("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n",
528528
n_accepted, curr_value.n_accepted);
529529
curr_value.n_accepted = n_accepted;
530530
}

common/speculative.cpp

Lines changed: 139 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <cstring>
1414
#include <iomanip>
1515
#include <map>
16+
#include <cinttypes>
1617

1718
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
1819
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
@@ -144,10 +145,28 @@ struct common_speculative_state {
144145
virtual void accept(uint16_t n_accepted) = 0;
145146
};
146147

148+
// Snapshot of the draft context's sequence state, used by common_speculative_state_draft
// to rewind the draft context instead of removing a partial range of positions.
struct common_speculative_checkpoint {
149+
// smallest position in the draft memory for the sequence at the time the checkpoint was taken
llama_pos pos_min = 0;
150+
// largest position in the draft memory at checkpoint time; restore removes everything after it
llama_pos pos_max = 0;
151+
152+
// number of prompt tokens covered by the checkpoint (prompt size minus batch size at creation)
int64_t n_tokens = 0;
153+
154+
// serialized sequence state bytes
std::vector<uint8_t> data;
155+
156+
// number of bytes currently held in `data`
size_t size() const {
157+
return data.size();
158+
}
159+
160+
// byte count returned when the checkpoint was written; 0 is treated as "no checkpoint available"
size_t ckpt_size = 0;
161+
};
162+
147163
struct common_speculative_state_draft : public common_speculative_state {
148164
llama_context * ctx_tgt; // only used for retokenizing from ctx_dft
149165
llama_context * ctx_dft;
150166

167+
struct common_speculative_checkpoint ckpt;
168+
bool use_checkpoint;
169+
151170
common_sampler * smpl;
152171

153172
llama_batch batch;
@@ -160,10 +179,12 @@ struct common_speculative_state_draft : public common_speculative_state {
160179
enum common_speculative_type type,
161180
llama_context * ctx_tgt,
162181
llama_context * ctx_dft,
163-
const std::vector<std::pair<std::string, std::string>> & replacements)
182+
const std::vector<std::pair<std::string, std::string>> & replacements,
183+
bool use_checkpoint)
164184
: common_speculative_state(type)
165185
, ctx_tgt(ctx_tgt)
166186
, ctx_dft(ctx_dft)
187+
, use_checkpoint(use_checkpoint)
167188
{
168189
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
169190
smpl = nullptr;
@@ -218,7 +239,48 @@ struct common_speculative_state_draft : public common_speculative_state {
218239
}
219240

220241
// Called at the start of a new generation. When checkpoints are enabled and one is
// currently stored (non-empty data buffer), discard it by resetting every field of
// `ckpt`. `prompt` is only used for the debug log message.
void begin(const llama_tokens & prompt) override {
221-
GGML_UNUSED(prompt);
242+
if (use_checkpoint && ckpt.size() > 0) {
243+
// delete checkpoint
244+
LOG_DBG("%s: delete checkpoint, prompt.size=%zu, pos_min=%d, pos_max=%d, n_tokens=%" PRId64 ", size=%.3f MiB\n",
245+
__func__, prompt.size(), ckpt.pos_min, ckpt.pos_max, ckpt.n_tokens, (float) ckpt.data.size() / 1024 / 1024);
246+
ckpt.pos_min = 0;
247+
ckpt.pos_max = 0;
248+
ckpt.n_tokens = 0;
249+
// ckpt_size == 0 is what draft() checks to decide that no checkpoint is available
ckpt.ckpt_size = 0;
250+
ckpt.data.clear();
251+
}
252+
}
253+
254+
// Serialize the state of sequence 0 of the draft context into `ckpt`.
//
// n_tokens_prompt: current number of tokens in the draft prompt
// n_tokens_batch:  number of tokens in the batch; subtracted from n_tokens_prompt to
//                  obtain the token count stored in the checkpoint
//
// Returns the number of bytes written into ckpt.data. Aborts if that number differs
// from the size reported by llama_state_seq_get_size_ext().
size_t draft_create_checkpoint(int n_tokens_prompt, int n_tokens_batch) {
255+
// the draft context is always accessed through sequence id 0 here
int slot_id = 0;
256+
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx_dft, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
257+
258+
// record the position range held by the draft memory at checkpoint time
ckpt.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_dft), slot_id);
259+
ckpt.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), slot_id);
260+
ckpt.n_tokens = n_tokens_prompt - n_tokens_batch;
261+
ckpt.data.resize(checkpoint_size);
262+
263+
// use the same flags as the size query above so the two byte counts are comparable
const size_t n = llama_state_seq_get_data_ext(ctx_dft, ckpt.data.data(), checkpoint_size, slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
264+
if (n != checkpoint_size) {
265+
GGML_ABORT("checkpoint size mismatch: expected %zu, got %zu\n", checkpoint_size, n);
266+
}
267+
268+
LOG_DBG("%s: pos_min = %d, pos_max = %d, size = %.3f MiB\n", __func__,
269+
ckpt.pos_min, ckpt.pos_max, (float) ckpt.data.size() / 1024 / 1024);
270+
return n;
271+
}
272+
273+
// Load the bytes stored in `ckpt` back into sequence 0 of the draft context, then
// remove all draft-memory positions after ckpt.pos_max.
//
// ckpt_size_part_expected: byte count that llama_state_seq_set_data_ext() must return;
//                          the caller passes the value recorded when the checkpoint
//                          was created (ckpt.ckpt_size)
//
// Returns the number of bytes reported by llama_state_seq_set_data_ext(). Aborts if it
// differs from ckpt_size_part_expected.
size_t draft_restore_checkpoint(size_t ckpt_size_part_expected) {
274+
int slot_id = 0;
275+
LOG_DBG("%s: pos_min = %d, pos_max = %d\n", __func__, ckpt.pos_min, ckpt.pos_max);
276+
const size_t n = llama_state_seq_set_data_ext(ctx_dft, ckpt.data.data(), ckpt.size(), slot_id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
277+
if (n != ckpt_size_part_expected) {
278+
// NOTE(review): the format string below opens a '(' that is never closed and has no
// trailing newline, unlike the abort message in draft_create_checkpoint()
GGML_ABORT("%s: failed to restore context checkpoint (pos_min=%d, pos_max=%d, size=%zu, get_data_ext->%zu, set_data_ext->%zu",
279+
__func__, ckpt.pos_min, ckpt.pos_max, ckpt.size(), ckpt_size_part_expected, n);
280+
}
281+
// drop positions past the checkpoint; -1 as the end presumably means "to the end of
// the sequence" - confirm against the llama_memory_seq_rm documentation
llama_memory_seq_rm(llama_get_memory(ctx_dft), slot_id, ckpt.pos_max + 1, -1);
282+
283+
return n;
222284
}
223285

224286
void draft(
@@ -236,8 +298,8 @@ struct common_speculative_state_draft : public common_speculative_state {
236298

237299
auto * mem_dft = llama_get_memory(ctx_dft);
238300

239-
int reuse_i = 0;
240-
int reuse_n = 0;
301+
int reuse_i = 0; // index of part to be reused in prompt_dft
302+
int reuse_n = 0; // length of part to be reused in prompt_dft
241303

242304
const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max;
243305

@@ -287,18 +349,26 @@ struct common_speculative_state_draft : public common_speculative_state {
287349
}
288350
}
289351

290-
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size());
352+
LOG_DBG("%s: reuse_i = %d, reuse_n = %d, #prompt_dft = %zu, #prompt_cur = %zu\n",
353+
__func__, reuse_i, reuse_n, prompt_dft.size(), prompt_cur.size());
354+
if (use_checkpoint && ckpt.ckpt_size == 0 && reuse_n > 0) {
355+
LOG_DBG("%s: no checkpoint available, no reuse, (reuse_i=%d, reuse_n=%d) -> (0, 0)\n",
356+
__func__, reuse_i, reuse_n);
357+
reuse_i = 0;
358+
reuse_n = 0;
359+
}
291360

292361
result.clear();
293362
result.reserve(params.n_max);
294363

295-
if (reuse_n == 0) {
364+
bool needs_ckpt = use_checkpoint && prompt_dft.size() > 0;
365+
if (reuse_n == 0 || (use_checkpoint && reuse_i > 0)) {
296366
llama_memory_clear(mem_dft, false);
297367
prompt_dft.clear();
298368
} else {
299369
// this happens when a previous draft has been discarded (for example, due to being too small), but the
300370
// target model agreed with it. in this case, we simply pass back the previous results to save compute
301-
if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
371+
if (reuse_i + reuse_n < (int64_t) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) {
302372
for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) {
303373
result.push_back(prompt_dft[i]);
304374

@@ -310,19 +380,50 @@ struct common_speculative_state_draft : public common_speculative_state {
310380
return;
311381
}
312382

383+
bool do_restore = false;
384+
if (prompt_dft.size() > prompt_cur.size() && reuse_i + reuse_n < (int64_t) prompt_dft.size()) {
385+
// This can happen after a partial acceptance (speculative decoding with checkpoints)
386+
LOG_DBG("%s: #prompt_dft=%zu, #prompt_cur=%zu, shorten draft\n",
387+
__func__, prompt_dft.size(), prompt_cur.size());
388+
prompt_dft.resize(prompt_cur.size());
389+
do_restore = true;
390+
}
391+
313392
if (reuse_i > 0) {
314-
llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
393+
bool is_removed = llama_memory_seq_rm (mem_dft, 0, 0, reuse_i);
394+
if (!is_removed) {
395+
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_i=%d\n", __func__, reuse_i);
396+
}
315397
llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i);
316398

317399
prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i);
318400
}
319401

320-
if (reuse_n < (int) prompt_dft.size()) {
321-
llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
322-
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
402+
if (reuse_n < (int) prompt_dft.size() || do_restore) {
403+
if (use_checkpoint) {
404+
if (ckpt.n_tokens > (int64_t) prompt_dft.size()) {
405+
LOG_INF("%s: checkpoint is too large, prompt_tgt.size=%zu, ckpt.n_tokens=%" PRId64 ", reuse_n=%d, prompt_dft.size=%zu\n",
406+
__func__, prompt_tgt.size(), ckpt.n_tokens, reuse_n, prompt_dft.size());
407+
}
408+
draft_restore_checkpoint(ckpt.ckpt_size);
409+
reuse_n = ckpt.n_tokens;
410+
prompt_dft.resize(reuse_n);
411+
needs_ckpt = false;
412+
} else {
413+
bool is_removed = llama_memory_seq_rm (mem_dft, 0, reuse_n, -1);
414+
if (!is_removed) {
415+
LOG_ERR("%s: llama_memory_seq_rm failed, reuse_n=%d, prompt_dft.size=%zu\n",
416+
__func__, reuse_n, prompt_dft.size());
417+
}
418+
prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end());
419+
}
323420
}
324421
}
325422

423+
if (needs_ckpt) {
424+
ckpt.ckpt_size = draft_create_checkpoint(prompt_dft.size(), batch.n_tokens);
425+
}
426+
326427
// prepare a batch to evaluate any new tokens in the prompt
327428
common_batch_clear(batch);
328429

@@ -337,7 +438,11 @@ struct common_speculative_state_draft : public common_speculative_state {
337438
if (batch.n_tokens > 0) {
338439
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
339440

340-
llama_decode(ctx_dft, batch);
441+
int ret = llama_decode(ctx_dft, batch);
442+
if (ret != 0 && ret != 1) {
443+
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu\n",
444+
__func__, ret, prompt_cur.size());
445+
}
341446
}
342447

343448
const llama_pos n_past = prompt_dft.size();
@@ -351,7 +456,11 @@ struct common_speculative_state_draft : public common_speculative_state {
351456

352457
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
353458

354-
llama_decode(ctx_dft, batch);
459+
int ret = llama_decode(ctx_dft, batch);
460+
if (ret != 0 && ret != 1) {
461+
LOG_WRN("%s: llama_decode returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
462+
__func__, ret, prompt_cur.size(), prompt_dft.size());
463+
}
355464

356465
common_sampler_reset(smpl);
357466

@@ -387,7 +496,11 @@ struct common_speculative_state_draft : public common_speculative_state {
387496
common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
388497

389498
// evaluate the drafted tokens on the draft model
390-
llama_decode(ctx_dft, batch);
499+
ret = llama_decode(ctx_dft, batch);
500+
if (ret != 0) {
501+
LOG_WRN("%s: llama_decode[%d] returned %d, prompt_cur.size=%zu, prompt_dft.size=%zu\n",
502+
__func__, i, ret, prompt_cur.size(), prompt_dft.size());
503+
}
391504

392505
prompt_dft.push_back(id);
393506
}
@@ -739,6 +852,7 @@ struct common_speculative_state_ngram_cache : public common_speculative_state {
739852

740853
// Top-level speculative decoding handle: owns the implementation states (via unique_ptr)
// and keeps a non-owning pointer to the one currently in use.
struct common_speculative {
741854
std::vector<std::unique_ptr<common_speculative_state>> impls; // list of implementations to use and their states
855+
742856
// non-owning; defaults to nullptr
common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats)
743857
};
744858

@@ -798,13 +912,13 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
798912
return it->second;
799913
}
800914

801-
bool common_speculative_is_compat(llama_context * ctx_tgt) {
915+
common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) {
802916
auto * mem = llama_get_memory(ctx_tgt);
803917
if (mem == nullptr) {
804-
return false;
918+
return COMMON_SPECULATIVE_COMPAT_TYPE_NO;
805919
}
806920

807-
bool res = true;
921+
common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL;
808922

809923
llama_memory_clear(mem, true);
810924

@@ -816,14 +930,14 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) {
816930
int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
817931
if (ret != 0) {
818932
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
819-
res = false;
933+
res = COMMON_SPECULATIVE_COMPAT_TYPE_NO;
820934
goto done;
821935
}
822936

823937
// try to remove the last tokens
824938
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
825939
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
826-
res = false;
940+
res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT;
827941
goto done;
828942
}
829943

@@ -909,9 +1023,10 @@ common_speculative * common_speculative_init(
9091023
break;
9101024
case COMMON_SPECULATIVE_TYPE_DRAFT: {
9111025
impls.push_back(std::make_unique<common_speculative_state_draft>(config.type,
912-
/* .ctx_tgt = */ ctx_tgt,
913-
/* .ctx_dft = */ ctx_dft,
914-
/* .replacements = */ params.replacements
1026+
/* .ctx_tgt = */ ctx_tgt,
1027+
/* .ctx_dft = */ ctx_dft,
1028+
/* .replacements = */ params.replacements,
1029+
/* .use_checkpoint= */ params.use_checkpoints // TODO: this should be based on the draft model!
9151030
));
9161031
break;
9171032
}
@@ -966,7 +1081,8 @@ common_speculative * common_speculative_init(
9661081
}
9671082

9681083
auto * result = new common_speculative {
969-
/* .impls = */ std::move(impls)
1084+
/* .impls = */ std::move(impls),
1085+
/* .curr_impl = */ nullptr,
9701086
};
9711087

9721088
return result;

0 commit comments

Comments
 (0)