Skip to content

Commit 3043a4b

Browse files
committed
Simplify to single layer
1 parent d69a72d commit 3043a4b

2 files changed

Lines changed: 33 additions & 19 deletions

File tree

common/speculative.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -704,9 +704,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
704704
break;
705705
}
706706

707-
// Step i+1: feed the i-th sampled draft token into the (i+1)-th
708-
// MTP block. Multi-block archs round-robin via mtp_step % N.
709-
llama_set_mtp_step(ctx_dft, (uint32_t)(i + 1));
707+
// Single-block-MTP-only: every AR step reuses the first MTP block
708+
// (Qwen MTP / vLLM single-MTP-layer style). mtp_step stays at 0;
709+
// trailing MTP blocks loaded from the GGUF are ignored at
710+
// runtime, and pruned GGUFs (block 0 only) work the same way.
711+
llama_set_mtp_step(ctx_dft, 0);
710712

711713
// evaluate the drafted tokens on the draft model
712714
ret = llama_decode(ctx_dft, batch);

src/models/step35.cpp

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
113113
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
114114
};
115115

116-
auto load_block_mtp = [&](int i) {
116+
auto load_block_mtp = [&](int i, bool is_first_mtp) {
117117
auto & layer = layers[i];
118118

119119
const uint32_t n_head_l = hparams.n_head(i);
@@ -123,7 +123,14 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
123123
// The MTP block is a full Step3p5 decoder layer (mtp_block) plus the
124124
// NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head).
125125
// `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only.
126-
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags);
126+
//
127+
// Only the FIRST MTP block (i == n_main) is required for the
128+
// single-block MTP runtime; trailing MTP blocks are always tolerated
129+
// as missing so pruned GGUFs (block 0 only) load cleanly. Override
130+
// mtp_flags to NOT_REQUIRED for those.
131+
const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED);
132+
133+
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
127134
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
128135
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED);
129136

@@ -134,12 +141,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
134141
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED);
135142
}
136143

137-
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags);
138-
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags);
144+
create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags);
145+
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags);
139146

140147
layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED);
141148

142-
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags);
149+
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags);
143150

144151
// dense MLP (leading dense blocks) — present if the MTP block isn't MoE
145152
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
@@ -159,9 +166,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
159166
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED);
160167

161168
// NextN-specific tensors that define the MTP block.
162-
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags);
163-
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags);
164-
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags);
169+
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags);
170+
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags);
171+
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags);
165172
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
166173
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
167174
layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED);
@@ -170,8 +177,13 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) {
170177
for (int i = 0; i < (int) n_main; ++i) {
171178
load_block_trunk(i, trunk_flags);
172179
}
180+
// Only the first MTP block (i == n_main) is required at runtime — the
181+
// single-block-MTP graph in build_arch_graph always uses that one.
182+
// Trailing MTP blocks are loaded if present (so an un-pruned GGUF with
183+
// all MTP layers still works) but tolerated when absent via the pruning
184+
// path. See scripts/prune_step35_extra_mtp.py for the pruner.
173185
for (int i = (int) n_main; i < n_layer; ++i) {
174-
load_block_mtp(i);
186+
load_block_mtp(i, /*is_first_mtp=*/ i == (int) n_main);
175187
}
176188
}
177189

@@ -362,13 +374,13 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
362374
: llm_graph_context(params) {
363375
GGML_ASSERT(hparams.nextn_predict_layers > 0 && "STEP35 MTP requires nextn_predict_layers > 0");
364376

365-
// Round-robin across MTP blocks at draft step boundaries. Matches vllm's
366-
// `current_step_idx = spec_step_idx % num_mtp_layers` (step3p5_mtp.py).
367-
// The first MTP block lives at layer index `n_main`; the speculative
368-
// driver bumps `cparams.mtp_step` between AR iterations.
369-
const int n_main = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
370-
const int step_offset = (int) (cparams.mtp_step % hparams.nextn_predict_layers);
371-
const int il = n_main + step_offset;
377+
// Single-block MTP only: always run the first trained MTP block (Qwen
378+
// MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to
379+
// be a much deeper refactor than this PR justifies; the trailing MTP
380+
// blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just
381+
// block 0) also work — see load_arch_tensors above and
382+
// scripts/prune_step35_extra_mtp.py.
383+
const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
372384
const auto & layer = model.layers[il];
373385

374386
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");

0 commit comments

Comments
 (0)