Skip to content

Commit e344c4a

Browse files
committed
dflash: remove rebundant logic & correct bias naming
1 parent 85a0089 commit e344c4a

3 files changed

Lines changed: 12 additions & 53 deletions

File tree

examples/speculative-simple/speculative-simple.cpp

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -232,21 +232,6 @@ int main(int argc, char ** argv) {
232232

233233
const auto t_dec_start = ggml_time_us();
234234

235-
// Hybrid targets (e.g. Qwen3.5) have recurrent layers that cannot be partially rolled back via seq_rm.
236-
// For them, snapshot the target state before verify and, on rejection, restore it and replay only the accepted tokens to ensure correctness
237-
// This is not efficient because the target model may run twice, but it is required in current llama.cpp design
238-
const bool use_state_snapshot = params.speculative.dflash && llama_model_is_hybrid(model_tgt);
239-
if (params.speculative.dflash) {
240-
LOG_INF("%s: DFlash target=%s, using %s rollback path\n", __func__,
241-
llama_model_is_hybrid(model_tgt) ? "hybrid" : "pure-attention",
242-
use_state_snapshot ? "snapshot+restore" : "seq_rm");
243-
}
244-
std::vector<uint8_t> state_snap;
245-
if (use_state_snapshot) {
246-
const size_t sz = llama_state_seq_get_size(ctx_tgt, 0);
247-
state_snap.resize(sz);
248-
}
249-
250235
while (true) {
251236
// generate or reuse draft tokens
252237
//
@@ -294,17 +279,6 @@ int main(int argc, char ** argv) {
294279

295280
GGML_ASSERT(n_draft > 0);
296281

297-
// snapshot target state for potential rollback (hybrid/recurrent targets only)
298-
const int n_past_before = n_past;
299-
const llama_token id_last_saved = id_last;
300-
if (use_state_snapshot) {
301-
const size_t sz = llama_state_seq_get_size(ctx_tgt, 0);
302-
if (sz > state_snap.size()) {
303-
state_snap.resize(sz);
304-
}
305-
llama_state_seq_get_data(ctx_tgt, state_snap.data(), sz, 0);
306-
}
307-
308282
// always have a token to evaluate from before - id_last
309283
common_batch_clear(batch_tgt);
310284
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
@@ -403,21 +377,6 @@ int main(int argc, char ** argv) {
403377
draft.clear();
404378

405379
{
406-
// const bool had_rejection = ids.size() < draft.size() + 1;
407-
408-
// if (use_state_snapshot && had_rejection) {
409-
// // Restore snapshot and replay the committed prefix (id_last + accepted drafts) so target state exactly
410-
// LOG_DBG("DFlash rollback: restore target state and replay %zu tokens\n", ids.size());
411-
// llama_state_seq_set_data(ctx_tgt, state_snap.data(), state_snap.size(), 0);
412-
// common_batch_clear(batch_tgt);
413-
// common_batch_add(batch_tgt, id_last_saved, n_past_before, { 0 }, true);
414-
// for (size_t i = 0; i + 1 < ids.size(); ++i) {
415-
// common_batch_add(batch_tgt, ids[i], n_past_before + 1 + i, { 0 }, true);
416-
// }
417-
// if (batch_tgt.n_tokens > 0) {
418-
// llama_decode(ctx_tgt, batch_tgt);
419-
// }
420-
// } else {
421380
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
422381

423382
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);

src/llama-model.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7382,10 +7382,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
73827382
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
73837383
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
73847384

7385-
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, TENSOR_NOT_REQUIRED);
7386-
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa}, TENSOR_NOT_REQUIRED);
7387-
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa}, TENSOR_NOT_REQUIRED);
7388-
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
7385+
layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, TENSOR_NOT_REQUIRED);
7386+
layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa}, TENSOR_NOT_REQUIRED);
7387+
layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa}, TENSOR_NOT_REQUIRED);
7388+
layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
73897389

73907390
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
73917391
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);

src/models/dflash.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,25 +67,25 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons
6767

6868
// Q from noise only
6969
ggml_tensor * Qcur = build_lora_mm(layer.wq, noise_norm);
70-
if (layer.bq) { Qcur = ggml_add(ctx0, Qcur, layer.bq); }
70+
if (layer.wq_b) { Qcur = ggml_add(ctx0, Qcur, layer.wq_b); }
7171
cb(Qcur, "Qcur", il);
7272

7373
// K = concat(k_proj(target_ctx), k_proj(noise))
7474
ggml_tensor * K_tgt = build_lora_mm(layer.wk, target_ctx);
7575
ggml_tensor * K_noise = build_lora_mm(layer.wk, noise_norm);
76-
if (layer.bk) {
77-
K_tgt = ggml_add(ctx0, K_tgt, layer.bk);
78-
K_noise = ggml_add(ctx0, K_noise, layer.bk);
76+
if (layer.wk_b) {
77+
K_tgt = ggml_add(ctx0, K_tgt, layer.wk_b);
78+
K_noise = ggml_add(ctx0, K_noise, layer.wk_b);
7979
}
8080
ggml_tensor * Kcur = ggml_concat(ctx0, K_tgt, K_noise, 1);
8181
cb(Kcur, "Kcur", il);
8282

8383
// V = concat(v_proj(target_ctx), v_proj(noise))
8484
ggml_tensor * V_tgt = build_lora_mm(layer.wv, target_ctx);
8585
ggml_tensor * V_noise = build_lora_mm(layer.wv, noise_norm);
86-
if (layer.bv) {
87-
V_tgt = ggml_add(ctx0, V_tgt, layer.bv);
88-
V_noise = ggml_add(ctx0, V_noise, layer.bv);
86+
if (layer.wv_b) {
87+
V_tgt = ggml_add(ctx0, V_tgt, layer.wv_b);
88+
V_noise = ggml_add(ctx0, V_noise, layer.wv_b);
8989
}
9090
ggml_tensor * Vcur = ggml_concat(ctx0, V_tgt, V_noise, 1);
9191
cb(Vcur, "Vcur", il);
@@ -123,7 +123,7 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons
123123
cb(cur, "kqv_out", il);
124124

125125
cur = build_lora_mm(layer.wo, cur);
126-
if (layer.bo) { cur = ggml_add(ctx0, cur, layer.bo); }
126+
if (layer.wo_b) { cur = ggml_add(ctx0, cur, layer.wo_b); }
127127
cur = ggml_add(ctx0, cur, inpL);
128128
cb(cur, "attn_res", il);
129129

0 commit comments

Comments
 (0)