
Commit 3f26661

fix(mtp): pin capture mode to FULL_SEQ when MTP attached — eliminates pre_norm null
Pre-norm hidden capture warning ("hidden_at_pos_pre_norm returned null at base_pos=N") was firing every run and silently halving accept_rate.

Root cause: orchestrator switched capture_mode_ to LAST_ROW_ONLY after prefill, which captures only the last candidate row. On partial-accept chain iters the runner asks for the hidden at the COMMITTED position (earlier than the last candidate). That position is out of range, so the lookup returns null, MTP falls back to post-norm, and D>=2 accept is crushed.

Correct-by-construction fix:
- enable_hidden_seq_capture(true) now pins capture_mode_ = FULL_SEQ and sets the capture_pinned_ flag
- set_hidden_capture_mode / set_hidden_capture_scope become no-ops while pinned (MTP-bound targets cannot have capture demoted)
- Orchestrator's post-prefill LAST_ROW_ONLY toggle removed (would be a no-op anyway, but cleaner intent)

Measured (Claude Code 24K Heron, same hardware):
- accept_rate 0.41 -> 0.78 (model card ceiling 0.83 at gamma<=2)
- MTP decode 34.7 -> 49.7 tok/s (+43%)
- Zero "pre_norm returned null" warnings in log

vs blog max 29.6 tok/s: +68%. vs DFlash on same config: +141%.
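The failure mode described in the message can be reproduced in miniature. The sketch below is not the project's code: `CaptureSketch`, `committed_row_available`, and the position convention (committed position = `base_pos + n_accepted - 1`) are hypothetical names and assumptions chosen for illustration, and a single float stands in for a full hidden row. It shows only why a lookup at the committed row comes back null when just the last candidate row was captured.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical names throughout; this models the lookup, not the real target.
enum class CaptureMode { FULL_SEQ, LAST_ROW_ONLY };

struct CaptureSketch {
    CaptureMode mode = CaptureMode::FULL_SEQ;
    int first_pos = 0;          // absolute position of rows[0]
    std::vector<float> rows;    // one float stands in for one hidden row

    // Record one verify batch of n_tokens candidate rows starting at base_pos.
    void capture(int base_pos, int n_tokens) {
        assert(n_tokens > 0);
        rows.clear();
        if (mode == CaptureMode::FULL_SEQ) {
            first_pos = base_pos;
            for (int i = 0; i < n_tokens; ++i) rows.push_back((float)(base_pos + i));
        } else {
            first_pos = base_pos + n_tokens - 1;   // only the last candidate row
            rows.push_back((float)first_pos);
        }
    }

    // Returns nullptr when pos falls outside the captured range.
    const float * hidden_at_pos(int pos) const {
        const int idx = pos - first_pos;
        if (idx < 0 || idx >= (int)rows.size()) return nullptr;
        return &rows[(std::size_t)idx];
    }
};

// True if the row at the last committed position is still retrievable after a
// verify batch of n_candidates in which only the first n_accepted were kept.
// Assumed convention: committed position = base_pos + n_accepted - 1.
bool committed_row_available(CaptureMode mode, int base_pos,
                             int n_candidates, int n_accepted) {
    CaptureSketch cap;
    cap.mode = mode;
    cap.capture(base_pos, n_candidates);
    const int committed_pos = base_pos + n_accepted - 1;
    return cap.hidden_at_pos(committed_pos) != nullptr;
}
```

Under these assumptions, `committed_row_available(CaptureMode::LAST_ROW_ONLY, 100, 3, 1)` is false (one of three candidates accepted, committed row not captured), while the same call with `CaptureMode::FULL_SEQ` is true. A full accept (`n_accepted == n_candidates`) succeeds in both modes, which is consistent with the null showing up only on partial-accept iterations.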
1 parent 7e30f8b commit 3f26661

2 files changed

Lines changed: 31 additions & 7 deletions


dflash/src/common/mtp_orchestrator.cpp

Lines changed: 7 additions & 3 deletions
@@ -51,8 +51,10 @@ GenerateResult warm_and_decode(ModelBackend * backend,
     const int prompt_len = (int)req.prompt.size();
     const int prefill_ubatch = env_int("DFLASH27B_PREFILL_UBATCH", kDefaultPrefillUbatch);
 
-    target->enable_hidden_seq_capture(true);
-    target->set_hidden_capture_scope(DFlashTarget::VerifyCaptureScope::FULL_SEQ);
+    // Capture state is owned by the target+MTP attachment, not the orchestrator.
+    // MTP's attach() already enabled+pinned FULL_SEQ; calling here would be a
+    // no-op and an architectural smell (orchestrator reaching into target state).
+    target->enable_hidden_seq_capture(true);  // idempotent for MTP-bound target
 
     std::vector<float> all_prefill_hidden((size_t)prompt_len * hidden);
     int32_t last_tok = -1;
@@ -81,7 +83,9 @@ GenerateResult warm_and_decode(ModelBackend * backend,
     result.prefill_s = std::chrono::duration<double>(
         std::chrono::steady_clock::now() - t_prefill0).count();
 
-    target->set_hidden_capture_scope(DFlashTarget::VerifyCaptureScope::LAST_ROW_ONLY);
+    // No scope toggle here: MTP-pinned target stays FULL_SEQ for the chain's
+    // whole lifetime (partial-accept iters need the COMMITTED row, not the
+    // last-candidate row, so LAST_ROW_ONLY would silently return null).
 
     if (last_tok < 0) {
         result.error = "warm_and_decode: prefill produced invalid argmax";

dflash/src/qwen35/qwen35_dflash_target.h

Lines changed: 24 additions & 4 deletions
@@ -44,9 +44,19 @@ class Qwen35DFlashTarget : public DFlashTarget {
 
     // Enable per-position post-norm hidden capture during verify_batch.
     // Off by default; MTP modules that depend on hidden_at_pos() flip it on
-    // in attach(). Non-MTP paths (target_gen, DFlash drafter spec-decode)
-    // leave it off and avoid pinning the full [n_embd, n_tokens] tensor.
-    void enable_hidden_seq_capture(bool on) override { capture_hidden_seq_ = on; }
+    // in attach() — and ALSO pin capture_mode_ to FULL_SEQ so the runtime
+    // toggle below cannot demote it to LAST_ROW_ONLY (which captures the
+    // wrong row for partial-accept iterations of the MTP chain and silently
+    // returns null from hidden_at_pos_pre_norm).
+    void enable_hidden_seq_capture(bool on) override {
+        capture_hidden_seq_ = on;
+        if (on) {
+            capture_mode_ = VerifyCaptureMode::FULL_SEQ;
+            capture_pinned_ = true;
+        } else {
+            capture_pinned_ = false;
+        }
+    }
 
     // Hidden-sequence capture granularity. FULL_SEQ downloads the entire
     // [n_tokens, n_embd] post-norm + pre-norm hidden tensors device->host
@@ -62,13 +72,17 @@ class Qwen35DFlashTarget : public DFlashTarget {
         FULL_SEQ,       // default — required during prefill / warm_head_kv
         LAST_ROW_ONLY,  // decode mode — only hidden_at_pos(base_pos-1) used
     };
+    // Runtime toggle is a NO-OP once MTP has pinned capture to FULL_SEQ. This
+    // makes the partial-accept-returns-null bug impossible by construction —
+    // non-MTP callers can still use this freely; MTP callers cannot demote it.
     void set_hidden_capture_mode(VerifyCaptureMode mode) {
+        if (capture_pinned_) return;
         capture_mode_ = mode;
     }
     VerifyCaptureMode hidden_capture_mode() const { return capture_mode_; }
 
-    // DFlashTarget abstract bridge.
     void set_hidden_capture_scope(DFlashTarget::VerifyCaptureScope scope) override {
+        if (capture_pinned_) return;
         capture_mode_ = (scope == DFlashTarget::VerifyCaptureScope::LAST_ROW_ONLY)
             ? VerifyCaptureMode::LAST_ROW_ONLY
             : VerifyCaptureMode::FULL_SEQ;
@@ -216,6 +230,12 @@ class Qwen35DFlashTarget : public DFlashTarget {
     // Toggled by set_hidden_capture_mode(). See VerifyCaptureMode docs.
     VerifyCaptureMode capture_mode_ = VerifyCaptureMode::FULL_SEQ;
 
+    // Pinned by enable_hidden_seq_capture(true) — once an MTP module attaches
+    // and enables capture, the mode cannot be demoted to LAST_ROW_ONLY (that
+    // captures the wrong row for partial-accept chain iters and silently
+    // returns null from hidden_at_pos_pre_norm).
+    bool capture_pinned_ = false;
+
 #ifdef DFLASH_VERIFY_PROFILE
     // Per-instance accumulators: summed wall-clock (ms) per verify_batch call;
     // dumped from destructor. Zero-cost when flag is off.
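
The pin behaviour introduced in the header can be exercised in isolation. The following is a minimal sketch, assuming a standalone class with no `DFlashTarget` base and no tensor capture. `PinnedCaptureSketch` and `pin_walkthrough` are hypothetical names, and only the three members touched by the diff are modelled.

```cpp
#include <cassert>

// Hypothetical standalone class; mirrors the pin logic from the diff without
// the DFlashTarget base class or any actual hidden-state capture.
enum class VerifyCaptureMode { FULL_SEQ, LAST_ROW_ONLY };

class PinnedCaptureSketch {
public:
    // Enabling capture forces FULL_SEQ and pins it; disabling only unpins.
    void enable_hidden_seq_capture(bool on) {
        capture_hidden_seq_ = on;
        if (on) {
            capture_mode_   = VerifyCaptureMode::FULL_SEQ;
            capture_pinned_ = true;
        } else {
            capture_pinned_ = false;
        }
    }

    // Ignored while pinned, so a pinned target cannot be demoted.
    void set_hidden_capture_mode(VerifyCaptureMode mode) {
        if (capture_pinned_) return;
        capture_mode_ = mode;
    }

    VerifyCaptureMode hidden_capture_mode() const { return capture_mode_; }
    bool capture_enabled() const { return capture_hidden_seq_; }

private:
    bool              capture_hidden_seq_ = false;
    VerifyCaptureMode capture_mode_       = VerifyCaptureMode::FULL_SEQ;
    bool              capture_pinned_     = false;
};

// Steps through unpinned -> pinned -> unpinned and checks the mode each time.
void pin_walkthrough() {
    PinnedCaptureSketch t;

    // Unpinned: the runtime toggle works as before.
    t.set_hidden_capture_mode(VerifyCaptureMode::LAST_ROW_ONLY);
    assert(t.hidden_capture_mode() == VerifyCaptureMode::LAST_ROW_ONLY);

    // Enabling capture resets the mode to FULL_SEQ and pins it.
    t.enable_hidden_seq_capture(true);
    assert(t.hidden_capture_mode() == VerifyCaptureMode::FULL_SEQ);

    // A demotion attempt while pinned is silently ignored.
    t.set_hidden_capture_mode(VerifyCaptureMode::LAST_ROW_ONLY);
    assert(t.hidden_capture_mode() == VerifyCaptureMode::FULL_SEQ);

    // Disabling capture unpins; the mode stays FULL_SEQ until toggled again.
    t.enable_hidden_seq_capture(false);
    assert(t.hidden_capture_mode() == VerifyCaptureMode::FULL_SEQ);
    t.set_hidden_capture_mode(VerifyCaptureMode::LAST_ROW_ONLY);
    assert(t.hidden_capture_mode() == VerifyCaptureMode::LAST_ROW_ONLY);
}
```

One property worth noting from the walkthrough: `enable_hidden_seq_capture(false)` clears the pin but does not restore whatever mode was set before pinning, so a caller that wants `LAST_ROW_ONLY` after detaching has to request it again.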
