Skip to content

Commit 5c1bc29

Browse files
Merge remote-tracking branch 'origin/AgentMemory/fused-codegen-degeneration-fix-2815' into _train2
# Conflicts: # inference_engine/backends/mlx/fused_specdecode.py # inference_engine/bridge/manifest.py # scripts/research/k3_integrated_niah_eval_mac.py # tests/inference_engine/bridge/test_manifest.py Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
2 parents bc74bf9 + 772c8df commit 5c1bc29

5 files changed

Lines changed: 213 additions & 6 deletions

File tree

inference_engine/backends/mlx/fused_specdecode.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
restored_prefill_cache,
3636
)
3737

38-
3938
# --------------------------------------------------------------------------- #
4039
# Component A: capture verifier aux-layer hidden states (no transformers
4140
# `output_hidden_states` on MLX → patch the decoder-layer __call__).
@@ -387,6 +386,7 @@ def fused_specdecode_generate_mlx_trim(
387386
eos_ids: Sequence[int] = (),
388387
single_fused: bool = False,
389388
on_commit: Optional[Callable[[List[int]], None]] = None,
389+
stop_on_runaway: bool = True,
390390
) -> Dict[str, Any]:
391391
"""CUDA-parity fused spec decode: KEEP accepted K/V, TRIM only the rejected
392392
tail (no rollback, no carry re-forward). Requires the adapter to be
@@ -412,6 +412,7 @@ def fused_specdecode_generate_mlx_trim(
412412
generated: List[int] = []
413413
accepts: List[int] = []
414414
block_evals: List[float] = []
415+
stopped_on_runaway = False
415416
ctx_len = C
416417
try:
417418
while len(generated) < gen_tokens:
@@ -474,6 +475,12 @@ def fused_specdecode_generate_mlx_trim(
474475
timing["extend_s"] += time.perf_counter() - t_extend
475476
if any(t in eos for t in commit):
476477
break
478+
if stop_on_runaway:
479+
drop = _trailing_runaway_drop(generated)
480+
if drop > 0:
481+
del generated[len(generated) - drop:]
482+
stopped_on_runaway = True
483+
break
477484
finally:
478485
adapter._capture_aux = False
479486
generated = generated[:gen_tokens]
@@ -483,6 +490,7 @@ def fused_specdecode_generate_mlx_trim(
483490
"mean_accept_len": (round(sum(accepts) / len(accepts), 3)
484491
if accepts else 0.0),
485492
"decode_tokens": len(generated),
493+
"stopped_on_runaway": stopped_on_runaway,
486494
"loop": ("mlx_trim_single_fused_probe" if single_fused
487495
else "mlx_trim_keep_accepted_cuda_parity"),
488496
"single_fused": bool(single_fused),
@@ -505,6 +513,7 @@ def fused_specdecode_generate_mlx(
505513
block_size: int,
506514
eos_ids: Sequence[int] = (),
507515
on_commit: Optional[Callable[[List[int]], None]] = None,
516+
stop_on_runaway: bool = True,
508517
) -> Dict[str, Any]:
509518
"""All-MLX fused spec decode with ONE host sync per block.
510519
@@ -546,6 +555,7 @@ def fused_specdecode_generate_mlx(
546555

547556
generated: List[int] = []
548557
accepts: List[int] = []
558+
stopped_on_runaway = False
549559
# Rollback-carry state: rejected blocks roll the WHOLE forward back
550560
# (rollback_block — see its docstring for why trim is unsound on the
551561
# wrapped sliding ring) and carry the stream-committed-but-not-cached
@@ -630,6 +640,12 @@ def fused_specdecode_generate_mlx(
630640
timing["extend_s"] += time.perf_counter() - t_extend
631641
if any(t in eos for t in commit):
632642
break
643+
if stop_on_runaway:
644+
drop = _trailing_runaway_drop(generated)
645+
if drop > 0:
646+
del generated[len(generated) - drop:]
647+
stopped_on_runaway = True
648+
break
633649
finally:
634650
adapter._capture_aux = False
635651
generated = generated[:gen_tokens]
@@ -639,6 +655,7 @@ def fused_specdecode_generate_mlx(
639655
"mean_accept_len": (round(sum(accepts) / len(accepts), 3)
640656
if accepts else 0.0),
641657
"decode_tokens": len(generated),
658+
"stopped_on_runaway": stopped_on_runaway,
642659
"loop": "mlx_rollback_carry_v3",
643660
"time_breakdown_s": {k: round(v, 3) for k, v in timing.items()},
644661
}
@@ -671,6 +688,40 @@ def _sliding_ring_would_wrap(cache: Any, n_new: int) -> bool:
671688
return False
672689

673690

691+
def _trailing_runaway_drop(
692+
ids: Sequence[int],
693+
*,
694+
max_period: int = 8,
695+
min_reps: int = 12,
696+
keep_reps: int = 3,
697+
) -> int:
698+
"""Return how many TRAILING tokens to drop if ``ids`` ends in a runaway
699+
short-period loop, else 0.
700+
701+
A runaway loop is a unit of ``1..max_period`` tokens repeated ``>= min_reps``
702+
times back-to-back at the tail (e.g. the ``**``/``.2``/``*`` markdown-marker
703+
collapse greedy decoding falls into on code prompts). When found, we keep
704+
``keep_reps`` instances and drop the rest, so callers can stop generation
705+
with a clean tail instead of emitting an unbounded wall of repeats.
706+
707+
Deliberately CONSERVATIVE (>= 12 back-to-back repeats of a <= 8-token unit)
708+
so legitimately repetitive text — numbered lists, ``矿工 A/B/C`` enumerations,
709+
structured code — is never trimmed. Returns 0 when no runaway is present."""
710+
n = len(ids)
711+
for p in range(1, max_period + 1):
712+
if n < p * min_reps:
713+
continue
714+
unit = list(ids[n - p:])
715+
reps = 0
716+
i = n
717+
while i - p >= 0 and list(ids[i - p:i]) == unit:
718+
reps += 1
719+
i -= p
720+
if reps >= min_reps:
721+
return max((reps - keep_reps) * p, 0)
722+
return 0
723+
724+
674725
# --------------------------------------------------------------------------- #
675726
# The fused spec-decode loop (control flow; MLX/torch ops via injected fns).
676727
# --------------------------------------------------------------------------- #
@@ -689,6 +740,7 @@ def fused_specdecode_generate(
689740
cat_aux_fn: Callable[[Sequence[Any]], Any],
690741
allow_greedy_fallback: bool = True,
691742
on_commit: Optional[Callable[[List[int]], None]] = None,
743+
stop_on_runaway: bool = True,
692744
) -> Dict[str, Any]:
693745
"""Run the fused engine. ``adapter`` must already be prefilled. Per block:
694746
draft from the cached drafter context (B), verify+capture-aux incrementally
@@ -717,6 +769,7 @@ def fused_specdecode_generate(
717769
generated: List[int] = []
718770
accepts: List[int] = []
719771
fallback_to_greedy = False
772+
stopped_on_runaway = False
720773
try:
721774
while len(generated) < gen_tokens:
722775
L = min(block_size, gen_tokens - len(generated))
@@ -792,6 +845,17 @@ def fused_specdecode_generate(
792845
_emit(on_commit, generated)
793846
if any(t in eos for t in commit):
794847
break
848+
# Greedy decoding can collapse into a runaway short-period loop (e.g.
849+
# the **/.2/* markdown-marker wall on code prompts); the drafter then
850+
# trivially predicts the repeats and the greedy verifier accepts them,
851+
# so acceptance stays HIGH while the output is garbage. Stop on it
852+
# instead of emitting an unbounded wall (keeps a short clean tail).
853+
if stop_on_runaway:
854+
drop = _trailing_runaway_drop(generated)
855+
if drop > 0:
856+
del generated[len(generated) - drop:]
857+
stopped_on_runaway = True
858+
break
795859
if (allow_greedy_fallback and len(accepts) >= 2
796860
and (sum(accepts) / len(accepts)) < 1.5):
797861
fallback_to_greedy = True
@@ -810,6 +874,12 @@ def fused_specdecode_generate(
810874
_emit(on_commit, generated)
811875
if tok in eos:
812876
break
877+
if stop_on_runaway:
878+
drop = _trailing_runaway_drop(generated)
879+
if drop > 0:
880+
del generated[len(generated) - drop:]
881+
stopped_on_runaway = True
882+
break
813883
timing["fallback_greedy_s"] += time.perf_counter() - t_fb
814884
finally:
815885
adapter._capture_aux = False
@@ -820,5 +890,6 @@ def fused_specdecode_generate(
820890
"mean_accept_len": (round(sum(accepts) / len(accepts), 3)
821891
if accepts else 0.0),
822892
"decode_tokens": len(generated),
893+
"stopped_on_runaway": stopped_on_runaway,
823894
"time_breakdown_s": {k: round(v, 3) for k, v in timing.items()},
824895
}

inference_engine/bridge/manifest.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,63 @@ def _harness_preset(
835835
params={"max_new_tokens": ("int:max_new_tokens", "1300")},
836836
validate_reports=True, # §4 liveness + §2.4 quality gate on-device
837837
),
838+
Preset(
839+
name="mlx-kakeya-codegen-degen-probe",
840+
description="Regression probe (guard DISABLED): full f_θ fused engine "
841+
"on the multi-turn 'explain PoW || write PoW in C' chat "
842+
"that originally degenerated, with --fused-no-loop-guard so "
843+
"any greedy markdown-marker collapse is observable. Pairs "
844+
"with mlx-kakeya-codegen-guard-validate (guard ENABLED) to "
845+
"show the guard is what keeps the answer clean. On current "
846+
"code (post wrap-fix) both turns stay coherent.",
847+
command_templates=(
848+
(
849+
"python3", "scripts/research/k3_integrated_niah_eval_mac.py",
850+
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
851+
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
852+
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
853+
"--s5-exact-full-attn", "--fused-specdecode", "--force-f-theta",
854+
"--sink-size", "4", "--window-size", "64", "--block-size", "4",
855+
"--max-new-tokens", "{max_new_tokens}", "--ignore-turn-stop",
856+
"--chat", "--fused-no-loop-guard",
857+
"--chat-scripted",
858+
"请详细解释POW的工作原理||实现一个PoW的代码,用c语言完成",
859+
"--output", "results/research/codegen_degen_2815_longprompt.json",
860+
),
861+
),
862+
timeout_minutes=120,
863+
params={"max_new_tokens": ("int:max_new_tokens", "900")},
864+
validate_reports=False,
865+
),
866+
Preset(
867+
name="mlx-kakeya-codegen-guard-validate",
868+
description="Validate the runaway-loop guard end-to-end: full f_θ fused "
869+
"engine on the multi-turn 'explain PoW || write PoW in C' "
870+
"chat with the guard ENABLED (production default). The "
871+
"answer must stay coherent and never collapse into a marker "
872+
"wall — if a runaway starts, the guard stops it "
873+
"(stopped_on_runaway) leaving a clean tail. Confirmed "
874+
"coherent on current code; byte-identical to the guard-off "
875+
"probe (the guard is inert on healthy output).",
876+
command_templates=(
877+
(
878+
"python3", "scripts/research/k3_integrated_niah_eval_mac.py",
879+
"--verifier-path", "${ENV:KAKEYA_MAC_VERIFIER_PATH}",
880+
"--drafter-id", "${ENV:KAKEYA_MAC_DRAFTER_ID}",
881+
"--f-theta-dir", "${ENV:KAKEYA_MAC_FTHETA_DIR}",
882+
"--s5-exact-full-attn", "--fused-specdecode", "--force-f-theta",
883+
"--sink-size", "4", "--window-size", "64", "--block-size", "4",
884+
"--max-new-tokens", "{max_new_tokens}", "--ignore-turn-stop",
885+
"--chat",
886+
"--chat-scripted",
887+
"请详细解释POW的工作原理||实现一个PoW的代码,用c语言完成",
888+
"--output", "results/research/codegen_guard_validate_2815.json",
889+
),
890+
),
891+
timeout_minutes=120,
892+
params={"max_new_tokens": ("int:max_new_tokens", "900")},
893+
validate_reports=False,
894+
),
838895
Preset(
839896
name="mlx-kakeya-degen-probe",
840897
description="Long-decode regression probe: full f_θ fused engine on a "

scripts/research/k3_integrated_niah_eval_mac.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,15 @@ def parse_args() -> argparse.Namespace:
185185
"stdout (as the interactive CLI does) instead of the "
186186
"per-block [stream] timing lines — lets a non-tty bridge "
187187
"run capture the exact live output format.")
188+
ap.add_argument("--chat-scripted-file", default=None,
189+
help="Like --chat-scripted but reads the (possibly long, "
190+
"'||'-separated) scripted prompt from a UTF-8 file. Lets "
191+
"a long context be a committed fixture instead of a giant "
192+
"manifest argv. Overrides --chat-scripted when set.")
193+
ap.add_argument("--fused-no-loop-guard", action="store_true",
194+
help="DIAGNOSTIC: disable the fused engine's runaway-loop stop "
195+
"(default ON) so a degeneration probe can observe the full "
196+
"collapse. Production chat keeps the guard enabled.")
188197
ap.add_argument("--chat-native-ref", action="store_true",
189198
help="DIAGNOSTIC opt-in: before each chat turn, also run a "
190199
"plain NATIVE greedy AR decode of the SAME prompt for "
@@ -815,27 +824,29 @@ def _gen_turn(pid: List[int], on_commit=None) -> Dict[str, Any]:
815824
evicted_positions=evicted,
816825
prefill_chunk_size=args.prefill_chunk_size, full_kv=args.cuda_trim)
817826
t0 = time.perf_counter()
827+
_guard = not args.fused_no_loop_guard
818828
if mlx_drafter is not None and args.cuda_trim:
819829
res = fused_specdecode_generate_mlx_trim(
820830
adapter, active_drafter, aux_prompt=aux_prompt,
821831
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
822832
gen_tokens=args.max_new_tokens, block_size=args.block_size,
823833
eos_ids=chat_eos, single_fused=args.single_fused,
824-
on_commit=on_commit)
834+
on_commit=on_commit, stop_on_runaway=_guard)
825835
elif mlx_drafter is not None:
826836
res = fused_specdecode_generate_mlx(
827837
adapter, active_drafter, aux_prompt=aux_prompt,
828838
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
829839
gen_tokens=args.max_new_tokens, block_size=args.block_size,
830-
eos_ids=chat_eos, on_commit=on_commit)
840+
eos_ids=chat_eos, on_commit=on_commit,
841+
stop_on_runaway=_guard)
831842
else:
832843
res = fused_specdecode_generate(
833844
adapter, active_drafter, aux_prompt=aux_prompt,
834845
embed_fn=embed_fn, lm_head_fn=lm_head_fn,
835846
gen_tokens=args.max_new_tokens, block_size=args.block_size,
836847
eos_ids=chat_eos, argmax_fn=argmax_fn, arange_fn=arange_fn,
837848
cat_aux_fn=cat_aux_fn, allow_greedy_fallback=False,
838-
on_commit=on_commit)
849+
on_commit=on_commit, stop_on_runaway=_guard)
839850
res["decode_s"] = round(time.perf_counter() - t0, 3)
840851
res["f_theta_ran"] = f_theta_ran
841852
res["f_theta_layers"] = sorted(rk.keys()) if rk else []
@@ -899,8 +910,12 @@ def cb(toks: List[int]) -> None:
899910
file=sys.stderr, flush=True)
900911

901912
history: List[Dict[str, str]] = []
902-
if args.chat_scripted is not None:
903-
turns = [t for t in args.chat_scripted.split("||") if t.strip()]
913+
scripted = args.chat_scripted
914+
if args.chat_scripted_file is not None:
915+
with open(args.chat_scripted_file, encoding="utf-8") as _f:
916+
scripted = _f.read()
917+
if scripted is not None:
918+
turns = [t for t in scripted.split("||") if t.strip()]
904919
transcript = []
905920
for u in turns:
906921
history.append({"role": "user", "content": u})

tests/backends/mlx/test_fused_specdecode.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,68 @@ def __init__(self, offset):
170170
self.max_size = None
171171

172172

173+
def test_trailing_runaway_drop_detects_and_trims_loops():
    """A tail loop of >= 12 repeats is trimmed down to the default 3 kept reps."""
    # Period-1 unit repeated 20x: 20 - 3 kept = 17 tokens to drop.
    single_token_loop = [1, 2, 3] + [9] * 20
    assert fsd._trailing_runaway_drop(single_token_loop) == 17

    # Period-3 unit repeated 12x: (12 - 3) * 3 = 27 tokens to drop.
    three_token_loop = [5, 6] + [7, 8, 9] * 12
    assert fsd._trailing_runaway_drop(three_token_loop) == 27
181+
182+
183+
def test_trailing_runaway_drop_is_conservative():
    """Inputs that are not a runaway tail must never be trimmed."""
    untouched = (
        [9] * 11,           # fewer than min_reps (12) back-to-back
        list(range(40)),    # legitimate non-repeating tail
        [1, 2] * 10 + [3],  # period that does not tile the very tail
        [],                 # empty / short
    )
    for ids in untouched:
        assert fsd._trailing_runaway_drop(ids) == 0
192+
193+
194+
def test_fused_loop_stops_on_runaway_repeat():
    """The fused loop must cut a single-token runaway and report that it did."""
    marker = 42

    # The drafter always drafts the marker token and the verifier below
    # greedily agrees, so the committed stream becomes a runaway single-token
    # loop that the guard has to cut.
    class _LoopAdapter(_FakeAdapter):
        def forward_block(self, candidate):
            # Verifier greedily predicts the SAME marker token forever.
            if self._capture_aux:
                width = len(candidate)
                self._last_aux = [torch.zeros(width, self.hidden)]
            return [marker for _ in candidate]

    adapter = _LoopAdapter(prompt_len=5, first_token=marker)
    drafter = _FakeDrafter(drafts=[[marker, marker, marker]] * 60)
    res = fsd.fused_specdecode_generate(
        adapter,
        drafter,
        gen_tokens=400,
        block_size=4,
        eos_ids=(),
        allow_greedy_fallback=False,
        **_loop_kwargs(drafter),
    )

    assert res["stopped_on_runaway"] is True
    # Stopped early with a short clean tail, nowhere near the 400 budget.
    tokens = res["tokens"]
    assert len(tokens) < 40
    assert set(tokens) == {marker}
216+
217+
218+
def test_fused_loop_runaway_guard_can_be_disabled():
    """With ``stop_on_runaway=False`` the loop runs to the full token budget."""
    marker = 42
    budget = 120

    class _LoopAdapter(_FakeAdapter):
        def forward_block(self, candidate):
            # Same always-agreeing verifier as the guard-enabled test.
            if self._capture_aux:
                self._last_aux = [torch.zeros(len(candidate), self.hidden)]
            return [marker for _ in candidate]

    adapter = _LoopAdapter(prompt_len=5, first_token=marker)
    drafter = _FakeDrafter(drafts=[[marker, marker, marker]] * 200)
    res = fsd.fused_specdecode_generate(
        adapter,
        drafter,
        gen_tokens=budget,
        block_size=4,
        eos_ids=(),
        allow_greedy_fallback=False,
        stop_on_runaway=False,
        **_loop_kwargs(drafter),
    )

    assert res["stopped_on_runaway"] is False
    assert len(res["tokens"]) == budget  # ran to the full budget
233+
234+
173235
def test_sliding_ring_would_wrap_detects_wrap():
174236
# offset + n_new >= max_size -> the rotating ring becomes non-trimmable.
175237
cache = [_FakeRotating(offset=1022, max_size=1024)]

tests/inference_engine/bridge/test_manifest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def test_allowlist_contains_exactly_the_documented_presets():
8282
"mlx-env-probe",
8383
"mlx-kakeya-chat-smoke",
8484
"mlx-kakeya-chat-stream-probe",
85+
"mlx-kakeya-codegen-degen-probe",
86+
"mlx-kakeya-codegen-guard-validate",
8587
"mlx-kakeya-degen-probe",
8688
"mlx-kakeya-fused-chat-ftheta",
8789
"mlx-kakeya-fused-chat-smoke",

0 commit comments

Comments
 (0)