Skip to content

Commit 5fba55e

Browse files
committed
multistep process
1 parent 114b4dd commit 5fba55e

1 file changed

Lines changed: 95 additions & 1 deletion

File tree

common/speculative.cpp

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,10 +549,19 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
549549

550550
const size_t row_bytes = (size_t) n_embd * sizeof(float);
551551

552+
// Stacked MTP needs each block's KV pre-populated with the chain
553+
// context, not just MTP1. For multi-block archs we request logits at
554+
// every batch position so the masked-mode hidden-state extraction in
555+
// ctx_dft captures every row (we feed those rows into the next chain
556+
// step). The extra LM-head matmuls only fire during prefill / verify,
557+
// not during drafting, so the cost is bounded.
558+
const bool chain_prefill = (n_mtp_layers > 1);
559+
const int8_t want_logits = chain_prefill ? 1 : 0;
560+
552561
common_batch_clear(batch);
553562

554563
for (int k = 0; k < n_tokens; ++k) {
555-
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
564+
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, want_logits);
556565
}
557566

558567
// shift the tgt embeddings to the right by one position
@@ -587,12 +596,97 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
587596
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
588597
}
589598

599+
llama_set_mtp_step(ctx_dft, 0);
590600
const int32_t rc = llama_decode(ctx_dft, batch);
591601
if (rc != 0) {
592602
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
593603
return false;
594604
}
595605

606+
// For stacked MTP, run the remaining chain steps so blocks 2..N also
607+
// get their KV slots filled with the verified context. For each source batch index i (the loop below iterates over the target index i+1), a step feeds:
608+
// token[i] = batch_in.token[i+1] (the next trunk token)
609+
// pos [i] = batch_in.pos [i+1] (the next trunk position)
610+
// embd [i] = prev block's hidden at original batch index i
611+
// Each step drops the leading position of every sequence (no prior
612+
// chain hidden available there), so the deeper blocks lose a few
613+
// entries at sequence starts — negligible for prompts of meaningful
614+
// length, and correct: those positions never come up during drafting.
615+
if (chain_prefill) {
616+
// Snapshot the previous chain step's pre-norm hiddens (initially those of
617+
// the step-0 decode above), indexed by the original batch_in index so we can chain across the per-seq remap.
618+
std::vector<float> prev_hiddens((size_t) n_tokens * n_embd);
619+
{
620+
const float * h_dft = llama_get_embeddings_pre_norm(ctx_dft);
621+
if (h_dft != nullptr) {
622+
std::memcpy(prev_hiddens.data(), h_dft, (size_t) n_tokens * row_bytes);
623+
} else {
624+
LOG_WRN("%s: chain prefill skipped (ctx_dft pre-norm embeddings unavailable)\n", __func__);
625+
return true;
626+
}
627+
}
628+
629+
for (int step = 1; step < n_mtp_layers; ++step) {
630+
common_batch_clear(batch);
631+
632+
// Maps step-batch index -> original batch_in index so the
633+
// next iteration can look up this step's hidden by original batch_in index.
634+
std::vector<int32_t> step_idx_to_in(n_tokens, -1);
635+
636+
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
637+
const int32_t beg = i_batch_beg[seq_id];
638+
const int32_t end = i_batch_end[seq_id];
639+
if (beg < 0 || end < 0) {
640+
continue;
641+
}
642+
if (end - beg < step) {
643+
continue; // this seq is shorter than the chain depth so far
644+
}
645+
646+
for (int32_t i = beg + step; i <= end; ++i) {
647+
const int32_t prev_idx = i - 1; // MTP_{step-1}'s hidden at the source position
648+
const int32_t step_i = batch.n_tokens;
649+
650+
common_batch_add(batch, batch_in.token[i], batch_in.pos[i], { seq_id }, 1);
651+
std::memcpy(batch.embd + (size_t) step_i * n_embd,
652+
prev_hiddens.data() + (size_t) prev_idx * n_embd,
653+
row_bytes);
654+
655+
step_idx_to_in[step_i] = i;
656+
}
657+
}
658+
659+
if (batch.n_tokens == 0) {
660+
break;
661+
}
662+
663+
llama_set_mtp_step(ctx_dft, (uint32_t) step);
664+
const int32_t rc_step = llama_decode(ctx_dft, batch);
665+
if (rc_step != 0) {
666+
LOG_WRN("%s: chain prefill step %d llama_decode failed rc=%d\n", __func__, step, rc_step);
667+
break;
668+
}
669+
670+
// Gather this step's hidden states for the next iteration,
671+
// remapped to the original batch_in indexing.
672+
if (step + 1 < n_mtp_layers) {
673+
std::vector<float> next_prev((size_t) n_tokens * n_embd, 0.0f);
674+
for (int32_t step_i = 0; step_i < (int32_t) batch.n_tokens; ++step_i) {
675+
const int32_t in_i = step_idx_to_in[step_i];
676+
if (in_i < 0) {
677+
continue;
678+
}
679+
const float * h = llama_get_embeddings_pre_norm_ith(ctx_dft, step_i);
680+
std::memcpy(next_prev.data() + (size_t) in_i * n_embd, h, row_bytes);
681+
}
682+
prev_hiddens = std::move(next_prev);
683+
}
684+
}
685+
686+
// Reset so subsequent draft() starts the chain at block 0.
687+
llama_set_mtp_step(ctx_dft, 0);
688+
}
689+
596690
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
597691
if (i_batch_end[seq_id] < 0) {
598692
continue;

0 commit comments

Comments
 (0)