Skip to content

Commit d7bbcaf

Browse files
fix(mac-chat): stream tokens live so the CLI no longer looks frozen on long answers (#152)
Root cause (from the code): the interactive chat REPL is fully NON-streaming — _gen_turn runs the entire generation (up to max_new_tokens) before printing anything. On a code-gen prompt the answer is long and the f_θ path is slow (~3-5 tok/s, single-token past the wrap), so the terminal stays silent for minutes — indistinguishable from a freeze (user: '一进入就卡死,完全没有输出'). Not a deadlock (prior scripted code-gen runs completed). Fix: add an on_commit streaming callback to the 3 fused generate loops (safe _emit wrapper, never breaks decode) and have the chat REPL decode incrementally and print the delta LIVE (+ a per-block '[stream] blk=.. t=..s' stderr line that also proves the engine is progressing, not hung). New mlx-kakeya-chat-stream- probe preset runs the user's exact prompt to validate. Co-authored-by: Cursor Agent <cursoragent@cursor.com> Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
1 parent f13594d commit d7bbcaf

4 files changed

Lines changed: 88 additions & 7 deletions

File tree

inference_engine/backends/mlx/fused_specdecode.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,18 @@ def make_full_kv_prompt_cache(mlx_model: Any) -> List[Any]:
363363
return [KVCache() for _ in range(n)]
364364

365365

366+
def _emit(on_commit: Optional[Callable[[List[int]], None]],
367+
generated: List[int]) -> None:
368+
"""Invoke a streaming callback with the tokens committed so far, swallowing
369+
any error so token streaming can never break generation."""
370+
if on_commit is None:
371+
return
372+
try:
373+
on_commit(list(generated))
374+
except Exception: # pragma: no cover - streaming must never break decode
375+
pass
376+
377+
366378
def fused_specdecode_generate_mlx_trim(
367379
adapter: "MLXRestoredIncrementalVerifier",
368380
drafter: Any,
@@ -374,6 +386,7 @@ def fused_specdecode_generate_mlx_trim(
374386
block_size: int,
375387
eos_ids: Sequence[int] = (),
376388
single_fused: bool = False,
389+
on_commit: Optional[Callable[[List[int]], None]] = None,
377390
) -> Dict[str, Any]:
378391
"""CUDA-parity fused spec decode: KEEP accepted K/V, TRIM only the rejected
379392
tail (no rollback, no carry re-forward). Requires the adapter to be
@@ -440,6 +453,7 @@ def fused_specdecode_generate_mlx_trim(
440453
commit = check[:accepted]
441454
generated += commit
442455
accepts.append(accepted)
456+
_emit(on_commit, generated)
443457
adapter.next_token_logits = next_row
444458
aux_rows = adapter._last_aux_mx
445459
# KEEP accepted (positions base..base+accepted-1), TRIM rejected.
@@ -490,6 +504,7 @@ def fused_specdecode_generate_mlx(
490504
gen_tokens: int,
491505
block_size: int,
492506
eos_ids: Sequence[int] = (),
507+
on_commit: Optional[Callable[[List[int]], None]] = None,
493508
) -> Dict[str, Any]:
494509
"""All-MLX fused spec decode with ONE host sync per block.
495510
@@ -587,6 +602,7 @@ def fused_specdecode_generate_mlx(
587602
commit = check[:accepted]
588603
generated += commit
589604
accepts.append(accepted)
605+
_emit(on_commit, generated)
590606
tail_logits = next_row
591607
adapter.next_token_logits = next_row
592608
aux_rows = adapter._last_aux_mx # rows for positions base_fwd..base_fwd+k+L
@@ -672,6 +688,7 @@ def fused_specdecode_generate(
672688
arange_fn: Callable[[int, int], Any],
673689
cat_aux_fn: Callable[[Sequence[Any]], Any],
674690
allow_greedy_fallback: bool = True,
691+
on_commit: Optional[Callable[[List[int]], None]] = None,
675692
) -> Dict[str, Any]:
676693
"""Run the fused engine. ``adapter`` must already be prefilled. Per block:
677694
draft from the cached drafter context (B), verify+capture-aux incrementally
@@ -772,6 +789,7 @@ def fused_specdecode_generate(
772789
commit = candidate[:accepted] + [correction]
773790
generated += commit
774791
accepts.append(accepted)
792+
_emit(on_commit, generated)
775793
if any(t in eos for t in commit):
776794
break
777795
if (allow_greedy_fallback and len(accepts) >= 2
@@ -789,6 +807,7 @@ def fused_specdecode_generate(
789807
tok = int(argmax_fn(adapter.next_token_logits))
790808
adapter.append_token(tok)
791809
generated.append(tok)
810+
_emit(on_commit, generated)
792811
if tok in eos:
793812
break
794813
timing["fallback_greedy_s"] += time.perf_counter() - t_fb

inference_engine/bridge/manifest.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,32 @@ def _harness_preset(
749749
},
750750
validate_reports=True, # §4 liveness gate: asserts f_theta_ran on-device
751751
),
752+
Preset(
753+
name="mlx-kakeya-chat-stream-probe",
754+
description="Reproduce + validate the 'CLI looks frozen on a code "
755+
"prompt' report: full f_θ chat on the user's exact prompt "
756+
"(根据pow的机制,给出完整的c代码实现). With token streaming the log "
757+
"shows incremental '[stream] blk=.. t=..s' lines as tokens "
758+
"commit (proving the engine is generating, not deadlocked) "
759+
"and the answer text builds up over time rather than after "
760+
"a long silence.",
761+
command_templates=(
762+
(
763+
"python3", "scripts/research/k3_integrated_niah_eval_mac.py",
764+
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
765+
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
766+
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
767+
"--s5-exact-full-attn", "--fused-specdecode", "--force-f-theta",
768+
"--sink-size", "4", "--window-size", "64", "--block-size", "4",
769+
"--max-new-tokens", "{max_new_tokens}", "--ignore-turn-stop",
770+
"--chat", "--chat-scripted", "根据pow的机制,给出完整的c代码实现",
771+
"--output", "results/research/chat_stream_probe_2815.json",
772+
),
773+
),
774+
timeout_minutes=90,
775+
params={"max_new_tokens": ("int:max_new_tokens", "200")},
776+
validate_reports=False,
777+
),
752778
Preset(
753779
name="mlx-kakeya-launcher-smoke",
754780
description="Verify the one-command local launcher "

scripts/research/k3_integrated_niah_eval_mac.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ def _encode_chat(history: List[Dict[str, str]]) -> List[int]:
769769
history, add_generation_prompt=True)
770770
return list(cids.tolist() if hasattr(cids, "tolist") else cids)
771771

772-
def _gen_turn(pid: List[int]) -> Dict[str, Any]:
772+
def _gen_turn(pid: List[int], on_commit=None) -> Dict[str, Any]:
773773
# Opt-in A/B control (--chat-native-ref): a plain NATIVE greedy
774774
# AR decode of the SAME prompt for --max-new-tokens. Captured as
775775
# res["native_ref_text"] so the fused answer can be compared
@@ -815,20 +815,22 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]:
815815
adapter, active_drafter, aux_prompt=aux_prompt,
816816
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
817817
gen_tokens=args.max_new_tokens, block_size=args.block_size,
818-
eos_ids=chat_eos, single_fused=args.single_fused)
818+
eos_ids=chat_eos, single_fused=args.single_fused,
819+
on_commit=on_commit)
819820
elif mlx_drafter is not None:
820821
res = fused_specdecode_generate_mlx(
821822
adapter, active_drafter, aux_prompt=aux_prompt,
822823
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
823824
gen_tokens=args.max_new_tokens, block_size=args.block_size,
824-
eos_ids=chat_eos)
825+
eos_ids=chat_eos, on_commit=on_commit)
825826
else:
826827
res = fused_specdecode_generate(
827828
adapter, active_drafter, aux_prompt=aux_prompt,
828829
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
829830
gen_tokens=args.max_new_tokens, block_size=args.block_size,
830831
eos_ids=chat_eos, argmax_fn=argmax_fn, arange_fn=arange_fn,
831-
cat_aux_fn=cat_aux_fn, allow_greedy_fallback=False)
832+
cat_aux_fn=cat_aux_fn, allow_greedy_fallback=False,
833+
on_commit=on_commit)
832834
res["decode_s"] = round(time.perf_counter() - t0, 3)
833835
res["f_theta_ran"] = f_theta_ran
834836
res["f_theta_layers"] = sorted(rk.keys()) if rk else []
@@ -853,6 +855,32 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]:
853855
sum(int(getattr(c, "nbytes", 0)) for c in (adapter._cache or [])))
854856
return res
855857

858+
def _make_stream_cb(to_stdout: bool):
859+
"""on_commit callback: incrementally decode the committed tokens
860+
and (interactive) print the new delta to stdout so the user sees
861+
the answer build LIVE instead of waiting for the whole generation.
862+
Always logs a per-block timing line to stderr (proves streaming /
863+
rules out a hang)."""
864+
st = {"chars": 0, "blk": 0, "t0": time.perf_counter()}
865+
866+
def cb(toks: List[int]) -> None:
867+
st["blk"] += 1
868+
try:
869+
txt = tokenizer.decode(toks, skip_special_tokens=True)
870+
except TypeError:
871+
txt = tokenizer.decode(toks)
872+
if to_stdout:
873+
delta = txt[st["chars"]:]
874+
if delta:
875+
sys.stdout.write(delta)
876+
sys.stdout.flush()
877+
st["chars"] = len(txt)
878+
sys.stderr.write(
879+
f"[stream] blk={st['blk']} tok={len(toks)} "
880+
f"t={time.perf_counter() - st['t0']:.1f}s\n")
881+
sys.stderr.flush()
882+
return cb
883+
856884
print(f"[chat] FULL fused engine: verifier={args.verifier_path} "
857885
f"drafter={args.drafter_id} f_theta={args.f_theta_dir} "
858886
f"S5 sink={args.sink_size} window={args.window_size} "
@@ -865,7 +893,8 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]:
865893
transcript = []
866894
for u in turns:
867895
history.append({"role": "user", "content": u})
868-
res = _gen_turn(_encode_chat(history))
896+
res = _gen_turn(_encode_chat(history),
897+
on_commit=_make_stream_cb(to_stdout=False))
869898
history.append({"role": "assistant", "content": res["text"]})
870899
tps = (res["decode_tokens"] / res["decode_s"]
871900
if res["decode_s"] > 0 else 0.0)
@@ -926,11 +955,17 @@ def _gen_turn(pid: List[int]) -> Dict[str, Any]:
926955
if not u:
927956
break
928957
history.append({"role": "user", "content": u})
929-
res = _gen_turn(_encode_chat(history))
958+
# Stream the answer LIVE so the terminal shows progress as tokens
959+
# commit (the f_θ path is slow; without this the CLI looks frozen
960+
# for minutes on long answers like code generation).
961+
sys.stdout.write("gemma-4> ")
962+
sys.stdout.flush()
963+
res = _gen_turn(_encode_chat(history),
964+
on_commit=_make_stream_cb(to_stdout=True))
930965
history.append({"role": "assistant", "content": res["text"]})
931966
tps = (res["decode_tokens"] / res["decode_s"]
932967
if res["decode_s"] > 0 else 0.0)
933-
sys.stdout.write("gemma-4> " + res["text"] + "\n")
968+
sys.stdout.write("\n")
934969
sys.stdout.flush()
935970
print(f"[chat] blocks={res['blocks']} accept_len="
936971
f"{res['mean_accept_len']} {round(tps,2)} tok/s "

tests/inference_engine/bridge/test_manifest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def test_allowlist_contains_exactly_the_documented_presets():
8181
"mlx-batched-pad-decode",
8282
"mlx-env-probe",
8383
"mlx-kakeya-chat-smoke",
84+
"mlx-kakeya-chat-stream-probe",
8485
"mlx-kakeya-degen-probe",
8586
"mlx-kakeya-fused-chat-ftheta",
8687
"mlx-kakeya-fused-chat-smoke",

0 commit comments

Comments
 (0)