Skip to content

Commit 129dbe7

Browse files
committed
feat(mtp): MTP-via-daemon — Qwen3.6 MTP head speculation through server.py
End-to-end Multi-Token-Prediction speculation through the HTTP daemon, unlocking MTP for real coding agents. Includes all PR #213 daemon fixes plus the new MTP pipeline. Headline: Claude Code on a 24K-token system prompt decodes at 35.8 tok/s via MTP vs 22.0 via DFlash (+63%). Mean across 7 harness clients: 29.3 tok/s MTP vs 22.0 DFlash (+33%). At/above the top of Lucebox's published blog range (22.6-29.6 tok/s). See PR description for full details and validation. Squashes: fix(qwen35,server): make HTTP daemon path generate real content fix(qwen35): restore_and_generate actually skips cached prefix fix(qwen35): reset cache state between bare-prompt requests fix(qwen35): populate result.prefill_s/decode_s feat(mtp): wire MTP through daemon (5-phase implementation) fix(mtp): n_ctx env override for large agent prompts
1 parent 0e0c93f commit 129dbe7

8 files changed

Lines changed: 3324 additions & 60 deletions

File tree

dflash/scripts/server.py

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,11 @@ def build_app(target: Path, draft: Path | None, bin_path: Path, budget: int, max
696696
verify_mode: str = "ddtree",
697697
extra_daemon_args: list[str] | None = None,
698698
lazy_draft: bool = False,
699-
verbose_daemon: bool = False) -> FastAPI:
699+
verbose_daemon: bool = False,
700+
mtp_gguf: Path | None = None,
701+
mtp_gamma: int = 3,
702+
mtp_draft_source: str = "chain",
703+
mtp_draft_topk: int = 1) -> FastAPI:
700704
import asyncio
701705
if _extra_daemon_has_target_sharding(extra_daemon_args):
702706
if prefix_cache_slots > 0 or prefill_cache_slots > 0:
@@ -753,6 +757,19 @@ async def _openai_compat_error_handler(_request: Request, exc: OpenAICompatError
753757
cmd = [bin_abs, str(target), "--daemon",
754758
f"--max-ctx={max_ctx}",
755759
f"--stream-fd={stream_fd_val}"]
760+
elif mtp_gguf is not None:
761+
# MTP mode: no --draft (MTP head lives inside target or mtp_gguf),
762+
# no DFlash flags. Daemon dispatches to MTP code path via --mtp-gguf.
763+
cmd = [bin_abs, str(target), "--daemon",
764+
f"--max-ctx={max_ctx}",
765+
f"--stream-fd={stream_fd_val}",
766+
f"--mtp-gguf={mtp_gguf}",
767+
f"--gamma={mtp_gamma}",
768+
"--draft-source", mtp_draft_source]
769+
if mtp_draft_source == "mtp_topk":
770+
cmd.append(f"--draft-topk={mtp_draft_topk}")
771+
if extra_daemon_args:
772+
cmd.extend(extra_daemon_args)
756773
else:
757774
if draft is None:
758775
raise SystemExit("qwen35 arch requires --draft <draft.gguf|model.safetensors>")
@@ -999,6 +1016,8 @@ def _maybe_compress(msgs: list[dict], prompt_bin: Path, prompt_ids: list[int],
9991016
pass
10001017
return new_bin, new_ids
10011018

1019+
_vocab_size: int = getattr(tokenizer, "vocab_size", 0) or 0
1020+
10021021
def _token_stream(r, n_gen, timing=None):
10031022
generated = 0
10041023
hit_stop = False
@@ -1011,6 +1030,8 @@ def _token_stream(r, n_gen, timing=None):
10111030
if timing is not None:
10121031
timing["daemon_done"] = True
10131032
break
1033+
if _vocab_size and not (0 <= tok_id < _vocab_size):
1034+
continue
10141035
if timing and timing.get("t_first_tok") is None:
10151036
timing["t_first_tok"] = time.monotonic()
10161037
if hit_stop:
@@ -1048,6 +1069,8 @@ async def _astream_tokens(r, n_gen, timing=None):
10481069
if timing is not None:
10491070
timing["daemon_done"] = True
10501071
break
1072+
if _vocab_size and not (0 <= tok_id < _vocab_size):
1073+
continue
10511074
if timing and timing.get("t_first_tok") is None:
10521075
timing["t_first_tok"] = time.monotonic()
10531076
if hit_stop:
@@ -1413,9 +1436,15 @@ def emit_delta(text, kind):
14131436
accumulated_content += pre
14141437
out = emit_delta(pre, "content")
14151438
if out: yield out
1416-
if which == "think":
1439+
if which == "think" and _thinking_enabled(req.chat_template_kwargs):
14171440
window = window[idx + len(THINK_OPEN_TAG):]
14181441
mode = "reasoning"
1442+
elif which == "think":
1443+
# thinking disabled — keep tag in content
1444+
accumulated_content += THINK_OPEN_TAG
1445+
out = emit_delta(THINK_OPEN_TAG, "content")
1446+
if out: yield out
1447+
window = window[idx + len(THINK_OPEN_TAG):]
14191448
elif which == "think_close":
14201449
window = window[idx + len(THINK_CLOSE_TAG):]
14211450
else:
@@ -1594,10 +1623,11 @@ def emit_delta(text, kind):
15941623
i = first_stop_match(text, stops)
15951624
if i != -1:
15961625
text = text[:i]
1597-
# Parse reasoning and tool calls
1598-
thinking_enabled = True
1599-
if req.chat_template_kwargs:
1600-
thinking_enabled = req.chat_template_kwargs.get("enable_thinking", True)
1626+
# Parse reasoning and tool calls. Match the prompt-rendering default
1627+
# (enable_thinking=False) so that spontaneous <think> tags from Qwen3.6
1628+
# are kept in content instead of stripped into an empty message when
1629+
# the model runs out of tokens before emitting </think>.
1630+
thinking_enabled = _thinking_enabled(req.chat_template_kwargs)
16011631
cleaned, tool_calls = parse_tool_calls(text, tools=req.tools)
16021632
_remember_tool_call_text(text, tool_calls)
16031633
cleaned, reasoning = parse_reasoning(
@@ -2230,9 +2260,7 @@ async def _responses_non_stream(
22302260
except Exception: pass
22312261

22322262
text = tokenizer.decode(tokens, skip_special_tokens=True)
2233-
thinking_enabled = True
2234-
if chat_req.chat_template_kwargs:
2235-
thinking_enabled = chat_req.chat_template_kwargs.get("enable_thinking", True)
2263+
thinking_enabled = _thinking_enabled(chat_req.chat_template_kwargs)
22362264
cleaned, tool_calls = parse_tool_calls(text, tools=chat_req.tools)
22372265
_remember_tool_call_text(text, tool_calls)
22382266
cleaned, reasoning = parse_reasoning(
@@ -2420,9 +2448,16 @@ async def sse() -> AsyncIterator[str]:
24202448
yield _resp_sse("response.output_text.delta", {
24212449
"item_id": msg_item_id, "output_index": 0,
24222450
"content_index": 0, "delta": pre})
2423-
if which == "think":
2451+
if which == "think" and _thinking_enabled(chat_req.chat_template_kwargs):
24242452
window = window[idx + len(THINK_OPEN_TAG):]
24252453
mode = "reasoning"
2454+
elif which == "think":
2455+
# thinking disabled — keep tag in content
2456+
accumulated_text += THINK_OPEN_TAG
2457+
yield _resp_sse("response.output_text.delta", {
2458+
"item_id": msg_item_id, "output_index": 0,
2459+
"content_index": 0, "delta": THINK_OPEN_TAG})
2460+
window = window[idx + len(THINK_OPEN_TAG):]
24262461
elif which == "think_close":
24272462
window = window[idx + len(THINK_CLOSE_TAG):]
24282463
else:
@@ -2650,6 +2685,20 @@ def main():
26502685
help="Pass --draft-feature-mirror to test_dflash (safe cross-GPU feature path)")
26512686
ap.add_argument("--peer-access", action="store_true",
26522687
help="Pass --peer-access to test_dflash (prefer P2P memcpy when available)")
2688+
# ── MTP (Multi-Token Prediction) speculator ──────────────────────────────
2689+
# When --mtp-gguf is set, the daemon runs MTP-head speculation instead of
2690+
# DFlash+DDTree. --draft is ignored (the MTP head is in the same GGUF as
2691+
# target, or a separate fused GGUF). Prefix-cache slots are auto-disabled
2692+
# in MTP mode because RESTORE does not snapshot MTP head KV yet.
2693+
ap.add_argument("--mtp-gguf", type=Path, default=None,
2694+
help="Path to MTP-fused GGUF. When set, daemon runs MTP "
2695+
"speculation; --draft and DFlash flags are ignored.")
2696+
ap.add_argument("--mtp-gamma", type=int, default=3,
2697+
help="MTP chain depth (default 3; recommended D=3 per matrix bench)")
2698+
ap.add_argument("--mtp-draft-source", choices=["chain", "mtp_topk"], default="chain",
2699+
help="MTP draft generation strategy (default chain)")
2700+
ap.add_argument("--mtp-draft-topk", type=int, default=1,
2701+
help="Top-K for mtp_topk draft source (default 1, ignored for chain)")
26532702
add_cli_flags(ap)
26542703
args = ap.parse_args()
26552704
prefill_cfg = config_from_args(args)
@@ -2695,6 +2744,17 @@ def main():
26952744
# through the laguna daemon now, so --prefill-compression and
26962745
# --prefix-cache-slots behave the same as on the qwen35 path.
26972746
draft = None
2747+
elif args.mtp_gguf is not None:
2748+
# MTP mode: --draft is ignored; MTP head lives in the target (or in --mtp-gguf
2749+
# if separate). Force prefix/prefill cache off — RESTORE doesn't snapshot
2750+
# MTP head KV yet (planned for a follow-up PR).
2751+
if not args.mtp_gguf.is_file():
2752+
raise SystemExit(f"--mtp-gguf not found at {args.mtp_gguf}")
2753+
draft = None
2754+
if args.prefix_cache_slots > 0 or args.prefill_cache_slots > 0:
2755+
print(" [cfg] MTP mode: disabling prefix/prefill cache (MTP head KV snapshot not implemented)")
2756+
args.prefix_cache_slots = 0
2757+
args.prefill_cache_slots = 0
26982758
else:
26992759
draft = resolve_draft(args.draft) if args.draft.is_dir() else args.draft
27002760
if not draft.is_file():
@@ -2726,7 +2786,11 @@ def main():
27262786
verify_mode=args.verify_mode,
27272787
extra_daemon_args=placement.daemon_args or None,
27282788
lazy_draft=args.lazy_draft,
2729-
verbose_daemon=args.verbose_daemon)
2789+
verbose_daemon=args.verbose_daemon,
2790+
mtp_gguf=args.mtp_gguf,
2791+
mtp_gamma=args.mtp_gamma,
2792+
mtp_draft_source=args.mtp_draft_source,
2793+
mtp_draft_topk=args.mtp_draft_topk)
27302794

27312795
import uvicorn
27322796
logging.basicConfig(

dflash/scripts/test_server.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def mock_tokenizer():
2323
tokenizer.encode.return_value = [1]
2424
tokenizer.decode.return_value = "hello"
2525
tokenizer.apply_chat_template.return_value = "prompt"
26+
tokenizer.vocab_size = 151936
2627
return tokenizer
2728

2829

@@ -939,3 +940,60 @@ def test_responses_instructions_and_developer_merged(mock_os_read, mock_pipe,
939940
assert len(system_msgs) == 1
940941
assert "Top-level instructions." in system_msgs[0]["content"]
941942
assert "Developer context." in system_msgs[0]["content"]
943+
944+
945+
# ─── out-of-range token filtering (OverflowError regression) ───────
946+
947+
@patch("server.os.pipe")
948+
@patch("server.os.read")
949+
def test_out_of_range_token_non_streaming_returns_200(
950+
mock_os_read, mock_pipe, mock_tokenizer, app):
951+
"""Daemon emits a negative sentinel-like token (-2) that is not the EOS
952+
sentinel (-1). Without filtering, tokenizer.decode([-2]) raises
953+
OverflowError → 500. After the fix the token is silently dropped and
954+
the endpoint returns 200 with empty content rather than crashing."""
955+
mock_pipe.return_value = (1, 2)
956+
# Make decode raise for any negative token to mirror HF tokenizer behaviour
957+
def _decode(ids, **_kw):
958+
if any(t < 0 or t >= 151936 for t in ids):
959+
raise OverflowError("out of range integral type conversion attempted")
960+
return "hello"
961+
mock_tokenizer.decode.side_effect = _decode
962+
# Daemon stream: bogus token (-2) then EOS sentinel (-1)
963+
mock_os_read.side_effect = [struct.pack("<i", -2), struct.pack("<i", -1)]
964+
965+
client = TestClient(app)
966+
response = client.post("/v1/chat/completions", json={
967+
"model": MODEL_NAME,
968+
"messages": [{"role": "user", "content": "hi"}],
969+
"stream": False,
970+
})
971+
972+
assert response.status_code == 200
973+
data = response.json()
974+
assert "choices" in data
975+
assert data["choices"][0]["finish_reason"] == "stop"
976+
977+
978+
@patch("server.os.pipe")
979+
@patch("server.os.read")
980+
def test_out_of_range_token_streaming_returns_200(
981+
mock_os_read, mock_pipe, mock_tokenizer, app):
982+
"""Same contract for the streaming path: bad token is dropped, no crash."""
983+
mock_pipe.return_value = (1, 2)
984+
def _decode(ids, **_kw):
985+
if any(t < 0 or t >= 151936 for t in ids):
986+
raise OverflowError("out of range integral type conversion attempted")
987+
return ""
988+
mock_tokenizer.decode.side_effect = _decode
989+
mock_os_read.side_effect = [struct.pack("<i", -2), struct.pack("<i", -1)]
990+
991+
client = TestClient(app)
992+
response = client.post("/v1/chat/completions", json={
993+
"model": MODEL_NAME,
994+
"messages": [{"role": "user", "content": "hi"}],
995+
"stream": True,
996+
})
997+
998+
assert response.status_code == 200
999+
assert "data: [DONE]" in response.text

0 commit comments

Comments
 (0)