Skip to content

Commit 4f28cc8

Browse files
author
lvyichen
committed
feat: mtp support swa kvcache rollback && prompt cache
1 parent 432cbef commit 4f28cc8

13 files changed

Lines changed: 642 additions & 328 deletions

common/speculative.cpp

Lines changed: 448 additions & 151 deletions
Large diffs are not rendered by default.

common/speculative.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,28 @@ void common_speculative_free(common_speculative * spec);
2626

2727
// optionally call once at the beginning of a new generation
2828
void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt);
29+
// starts a new generation while preserving at most the retained common prefix that is
30+
// still valid in both the target and draft contexts
31+
void common_speculative_begin(
32+
common_speculative * spec,
33+
const llama_tokens & prompt,
34+
llama_pos retained_prefix_len);
35+
36+
llama_pos common_speculative_get_committed_prefix_len(
37+
const common_speculative * spec);
38+
39+
void common_speculative_invalidate_retained_state(
40+
common_speculative * spec);
2941

30-
void common_speculative_set_prompt_hidden_states(
42+
// supplies the token/hidden-state source used by the next MTP first pass; start_pos
43+
// is the target-context position of source_tokens[0]
44+
void common_speculative_set_first_pass_source(
3145
common_speculative * spec,
32-
const float * hidden_states,
33-
int32_t n_tokens,
34-
int32_t n_embd);
46+
const llama_tokens & source_tokens,
47+
const float * hidden_states,
48+
int32_t n_tokens,
49+
int32_t n_embd,
50+
llama_pos start_pos);
3551

3652
// sample up to n_draft tokens and add them to the batch using the draft model
3753
llama_tokens common_speculative_draft(
@@ -40,7 +56,8 @@ llama_tokens common_speculative_draft(
4056
const llama_tokens & prompt,
4157
llama_token id_last);
4258

43-
// informs the speculative decoder that n_accepted tokens were accepted by the target model
59+
// informs the speculative decoder that n_accepted tokens were accepted by the target model;
60+
// batch_idxs maps the frontier token and accepted draft tokens back to verifier output rows
4461
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted, const std::vector<int32_t> & batch_idxs);
4562

4663
// print statistics about the speculative decoding

examples/speculative-simple/speculative-simple.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,13 @@ int main(int argc, char ** argv) {
163163
std::memcpy(prompt_hidden.data() + i*llama_model_n_embd(model_tgt), hidden,
164164
llama_model_n_embd(model_tgt)*sizeof(float));
165165
}
166-
common_speculative_set_prompt_hidden_states(
166+
common_speculative_set_first_pass_source(
167167
spec,
168+
prompt_tgt,
168169
prompt_hidden.data(),
169170
prompt_tgt.size(),
170-
llama_model_n_embd(model_tgt));
171+
llama_model_n_embd(model_tgt),
172+
0);
171173
}
172174

173175
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

src/llama-batch.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ bool llama_batch_allocr::init(
2828
const llama_memory_i * memory,
2929
uint32_t n_embd,
3030
uint32_t n_seq_max,
31-
bool output_all,
32-
bool allow_non_contiguous_pos) {
31+
bool output_all) {
3332
clear();
3433

3534
batch = batch_inp;
@@ -314,11 +313,9 @@ bool llama_batch_allocr::init(
314313
}
315314
}
316315

317-
if (!allow_non_contiguous_pos) {
318-
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
319-
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
320-
return false;
321-
}
316+
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
317+
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
318+
return false;
322319
}
323320
}
324321
}

src/llama-batch.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ class llama_batch_allocr {
8181
const llama_memory_i * memory,
8282
uint32_t n_embd,
8383
uint32_t n_seq_max,
84-
bool output_all,
85-
bool allow_non_contiguous_pos = false);
84+
bool output_all);
8685

8786
const llama_batch & get_batch() const;
8887

src/llama-context.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "llama-impl.h"
55
#include "llama-batch.h"
66
#include "llama-io.h"
7-
#include "llama-kv-cache-iswa.h"
87
#include "llama-memory.h"
98
#include "llama-mmap.h"
109
#include "llama-model.h"
@@ -1589,10 +1588,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
15891588
}
15901589
}
15911590

1592-
const llama_memory_i * memory_for_batch = memory.get();
1593-
const bool allow_non_contiguous_pos = false;
1594-
1595-
if (!balloc->init(batch_inp, vocab, memory_for_batch, n_embd, n_seq_max, output_all, allow_non_contiguous_pos)) {
1591+
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
15961592
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
15971593
return -1;
15981594
}
@@ -1748,10 +1744,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
17481744
t_embd = res->get_embd_pooled();
17491745
}
17501746

1751-
const bool mtp_skip_output = false;
1752-
17531747
// extract logits
1754-
if (!mtp_skip_output && logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
1748+
if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) {
17551749
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
17561750
GGML_ASSERT(backend_res != nullptr);
17571751
GGML_ASSERT(logits.data != nullptr);
@@ -1766,7 +1760,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
17661760
}
17671761

17681762
// extract embeddings
1769-
if (!mtp_skip_output && embd.data && t_embd && n_outputs > 0) {
1763+
if (embd.data && t_embd && n_outputs > 0) {
17701764
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
17711765
GGML_ASSERT(backend_embd != nullptr);
17721766

src/llama-graph.cpp

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -502,15 +502,14 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
502502
}
503503

504504
void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
505-
// base tensors may not be allocated if the graph uses only SWA layers
505+
// In single-layer ISWA graphs, one branch can be pruned and never get a backend buffer.
506506
if (self_k_idxs && self_k_idxs->buffer) {
507507
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
508508
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
509509

510510
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
511511
}
512512

513-
// swa tensors may not be allocated if the graph uses only base layers
514513
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
515514
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
516515
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
@@ -534,21 +533,14 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
534533

535534
bool res = true;
536535

537-
// base tensors may not be allocated if the graph uses only SWA layers
538-
if (self_k_idxs && self_k_idxs->buffer) {
539-
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
540-
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
541-
542-
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
543-
}
536+
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
537+
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
544538

545-
// swa tensors may not be allocated if the graph uses only base layers
546-
if (self_k_idxs_swa && self_k_idxs_swa->buffer) {
547-
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
548-
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
539+
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
540+
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
549541

550-
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
551-
}
542+
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
543+
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
552544

553545
return res;
554546
}

src/llama-kv-cache-iswa.cpp

Lines changed: 14 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -209,38 +209,6 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & ba
209209
return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
210210
}
211211

212-
llama_memory_context_ptr llama_kv_cache_iswa::init_batch_with_sinfos(
213-
llama_batch_allocr & balloc,
214-
uint32_t n_ubatch,
215-
const llama_kv_cache::slot_info_vec_t & sinfos,
216-
bool is_inplace_update) {
217-
if (sinfos.empty()) {
218-
return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
219-
}
220-
221-
balloc.split_reset();
222-
223-
std::vector<llama_ubatch> ubatches;
224-
const uint32_t n_stream = kv_base->get_n_stream();
225-
while (true) {
226-
auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
227-
if (ubatch.n_tokens == 0) {
228-
break;
229-
}
230-
ubatches.push_back(std::move(ubatch)); // NOLINT
231-
}
232-
233-
if (ubatches.size() != sinfos.size()) {
234-
return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
235-
}
236-
237-
auto sinfos_base = sinfos;
238-
auto sinfos_swa = sinfos;
239-
240-
return std::make_unique<llama_kv_cache_iswa_context>(
241-
this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches), is_inplace_update);
242-
}
243-
244212
llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
245213
return std::make_unique<llama_kv_cache_iswa_context>(this);
246214
}
@@ -279,6 +247,20 @@ llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
279247
return kv_swa.get();
280248
}
281249

250+
void llama_kv_cache_iswa::set_swa_reuse_guard(llama_pos query_pos) {
251+
kv_base->clear_swa_reuse_guard();
252+
kv_swa->set_swa_reuse_guard(query_pos);
253+
}
254+
255+
void llama_kv_cache_iswa::clear_swa_reuse_guard() {
256+
kv_base->clear_swa_reuse_guard();
257+
kv_swa->clear_swa_reuse_guard();
258+
}
259+
260+
bool llama_kv_cache_iswa::consume_swa_reuse_guard_block_prepare() {
261+
return kv_swa->consume_swa_reuse_guard_block_prepare();
262+
}
263+
282264
//
283265
// llama_kv_cache_iswa_context
284266
//
@@ -313,19 +295,6 @@ llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
313295
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
314296
}
315297

316-
llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
317-
llama_kv_cache_iswa * kv,
318-
slot_info_vec_t sinfos_base,
319-
slot_info_vec_t sinfos_swa,
320-
std::vector<llama_ubatch> ubatches,
321-
bool is_inplace_update) :
322-
ubatches(std::move(ubatches)),
323-
// note: here we copy the ubatches. not sure if this is ideal
324-
ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches, is_inplace_update)),
325-
ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches, is_inplace_update)),
326-
status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
327-
}
328-
329298
llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
330299

331300
bool llama_kv_cache_iswa_context::next() {
@@ -373,12 +342,3 @@ const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa() const {
373342

374343
return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
375344
}
376-
377-
void llama_kv_cache_iswa_context::set_inplace(bool value) {
378-
auto * base = const_cast<llama_kv_cache_context *>(
379-
static_cast<const llama_kv_cache_context *>(ctx_base.get()));
380-
auto * swa = const_cast<llama_kv_cache_context *>(
381-
static_cast<const llama_kv_cache_context *>(ctx_swa.get()));
382-
if (base) { base->set_inplace(value); }
383-
if (swa) { swa ->set_inplace(value); }
384-
}

src/llama-kv-cache-iswa.h

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,6 @@ class llama_kv_cache_iswa : public llama_memory_i {
3939
uint32_t n_ubatch,
4040
bool embd_all) override;
4141

42-
llama_memory_context_ptr init_batch_with_sinfos(
43-
llama_batch_allocr & balloc,
44-
uint32_t n_ubatch,
45-
const llama_kv_cache::slot_info_vec_t & sinfos,
46-
bool is_inplace_update);
47-
4842
llama_memory_context_ptr init_full() override;
4943

5044
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
@@ -76,6 +70,11 @@ class llama_kv_cache_iswa : public llama_memory_i {
7670
llama_kv_cache * get_base() const;
7771
llama_kv_cache * get_swa () const;
7872

73+
void set_swa_reuse_guard(llama_pos query_pos);
74+
void clear_swa_reuse_guard();
75+
76+
bool consume_swa_reuse_guard_block_prepare();
77+
7978
private:
8079
const llama_hparams & hparams;
8180

@@ -108,12 +107,6 @@ class llama_kv_cache_iswa_context : public llama_memory_context_i {
108107
slot_info_vec_t sinfos_base,
109108
slot_info_vec_t sinfos_swa,
110109
std::vector<llama_ubatch> ubatches);
111-
llama_kv_cache_iswa_context(
112-
llama_kv_cache_iswa * kv,
113-
slot_info_vec_t sinfos_base,
114-
slot_info_vec_t sinfos_swa,
115-
std::vector<llama_ubatch> ubatches,
116-
bool is_inplace_update);
117110

118111
virtual ~llama_kv_cache_iswa_context();
119112

@@ -134,8 +127,6 @@ class llama_kv_cache_iswa_context : public llama_memory_context_i {
134127
const llama_kv_cache_context * get_base() const;
135128
const llama_kv_cache_context * get_swa() const;
136129

137-
void set_inplace(bool value);
138-
139130
private:
140131
//llama_kv_cache_iswa * kv;
141132

0 commit comments

Comments
 (0)