3535 restored_prefill_cache ,
3636)
3737
38-
3938# --------------------------------------------------------------------------- #
4039# Component A: capture verifier aux-layer hidden states (no transformers
4140# `output_hidden_states` on MLX → patch the decoder-layer __call__).
@@ -387,6 +386,7 @@ def fused_specdecode_generate_mlx_trim(
387386 eos_ids : Sequence [int ] = (),
388387 single_fused : bool = False ,
389388 on_commit : Optional [Callable [[List [int ]], None ]] = None ,
389+ stop_on_runaway : bool = True ,
390390) -> Dict [str , Any ]:
391391 """CUDA-parity fused spec decode: KEEP accepted K/V, TRIM only the rejected
392392 tail (no rollback, no carry re-forward). Requires the adapter to be
@@ -412,6 +412,7 @@ def fused_specdecode_generate_mlx_trim(
412412 generated : List [int ] = []
413413 accepts : List [int ] = []
414414 block_evals : List [float ] = []
415+ stopped_on_runaway = False
415416 ctx_len = C
416417 try :
417418 while len (generated ) < gen_tokens :
@@ -474,6 +475,12 @@ def fused_specdecode_generate_mlx_trim(
474475 timing ["extend_s" ] += time .perf_counter () - t_extend
475476 if any (t in eos for t in commit ):
476477 break
478+ if stop_on_runaway :
479+ drop = _trailing_runaway_drop (generated )
480+ if drop > 0 :
481+ del generated [len (generated ) - drop :]
482+ stopped_on_runaway = True
483+ break
477484 finally :
478485 adapter ._capture_aux = False
479486 generated = generated [:gen_tokens ]
@@ -483,6 +490,7 @@ def fused_specdecode_generate_mlx_trim(
483490 "mean_accept_len" : (round (sum (accepts ) / len (accepts ), 3 )
484491 if accepts else 0.0 ),
485492 "decode_tokens" : len (generated ),
493+ "stopped_on_runaway" : stopped_on_runaway ,
486494 "loop" : ("mlx_trim_single_fused_probe" if single_fused
487495 else "mlx_trim_keep_accepted_cuda_parity" ),
488496 "single_fused" : bool (single_fused ),
@@ -505,6 +513,7 @@ def fused_specdecode_generate_mlx(
505513 block_size : int ,
506514 eos_ids : Sequence [int ] = (),
507515 on_commit : Optional [Callable [[List [int ]], None ]] = None ,
516+ stop_on_runaway : bool = True ,
508517) -> Dict [str , Any ]:
509518 """All-MLX fused spec decode with ONE host sync per block.
510519
@@ -546,6 +555,7 @@ def fused_specdecode_generate_mlx(
546555
547556 generated : List [int ] = []
548557 accepts : List [int ] = []
558+ stopped_on_runaway = False
549559 # Rollback-carry state: rejected blocks roll the WHOLE forward back
550560 # (rollback_block — see its docstring for why trim is unsound on the
551561 # wrapped sliding ring) and carry the stream-committed-but-not-cached
@@ -630,6 +640,12 @@ def fused_specdecode_generate_mlx(
630640 timing ["extend_s" ] += time .perf_counter () - t_extend
631641 if any (t in eos for t in commit ):
632642 break
643+ if stop_on_runaway :
644+ drop = _trailing_runaway_drop (generated )
645+ if drop > 0 :
646+ del generated [len (generated ) - drop :]
647+ stopped_on_runaway = True
648+ break
633649 finally :
634650 adapter ._capture_aux = False
635651 generated = generated [:gen_tokens ]
@@ -639,6 +655,7 @@ def fused_specdecode_generate_mlx(
639655 "mean_accept_len" : (round (sum (accepts ) / len (accepts ), 3 )
640656 if accepts else 0.0 ),
641657 "decode_tokens" : len (generated ),
658+ "stopped_on_runaway" : stopped_on_runaway ,
642659 "loop" : "mlx_rollback_carry_v3" ,
643660 "time_breakdown_s" : {k : round (v , 3 ) for k , v in timing .items ()},
644661 }
@@ -671,6 +688,40 @@ def _sliding_ring_would_wrap(cache: Any, n_new: int) -> bool:
671688 return False
672689
673690
691+ def _trailing_runaway_drop (
692+ ids : Sequence [int ],
693+ * ,
694+ max_period : int = 8 ,
695+ min_reps : int = 12 ,
696+ keep_reps : int = 3 ,
697+ ) -> int :
698+ """Return how many TRAILING tokens to drop if ``ids`` ends in a runaway
699+ short-period loop, else 0.
700+
701+ A runaway loop is a unit of ``1..max_period`` tokens repeated ``>= min_reps``
702+ times back-to-back at the tail (e.g. the ``**``/``.2``/``*`` markdown-marker
703+ collapse greedy decoding falls into on code prompts). When found, we keep
704+ ``keep_reps`` instances and drop the rest, so callers can stop generation
705+ with a clean tail instead of emitting an unbounded wall of repeats.
706+
707+ Deliberately CONSERVATIVE (>= 12 back-to-back repeats of a <= 8-token unit)
708+ so legitimately repetitive text — numbered lists, ``矿工 A/B/C`` enumerations,
709+ structured code — is never trimmed. Returns 0 when no runaway is present."""
710+ n = len (ids )
711+ for p in range (1 , max_period + 1 ):
712+ if n < p * min_reps :
713+ continue
714+ unit = list (ids [n - p :])
715+ reps = 0
716+ i = n
717+ while i - p >= 0 and list (ids [i - p :i ]) == unit :
718+ reps += 1
719+ i -= p
720+ if reps >= min_reps :
721+ return max ((reps - keep_reps ) * p , 0 )
722+ return 0
723+
724+
674725# --------------------------------------------------------------------------- #
675726# The fused spec-decode loop (control flow; MLX/torch ops via injected fns).
676727# --------------------------------------------------------------------------- #
@@ -689,6 +740,7 @@ def fused_specdecode_generate(
689740 cat_aux_fn : Callable [[Sequence [Any ]], Any ],
690741 allow_greedy_fallback : bool = True ,
691742 on_commit : Optional [Callable [[List [int ]], None ]] = None ,
743+ stop_on_runaway : bool = True ,
692744) -> Dict [str , Any ]:
693745 """Run the fused engine. ``adapter`` must already be prefilled. Per block:
694746 draft from the cached drafter context (B), verify+capture-aux incrementally
@@ -717,6 +769,7 @@ def fused_specdecode_generate(
717769 generated : List [int ] = []
718770 accepts : List [int ] = []
719771 fallback_to_greedy = False
772+ stopped_on_runaway = False
720773 try :
721774 while len (generated ) < gen_tokens :
722775 L = min (block_size , gen_tokens - len (generated ))
@@ -792,6 +845,17 @@ def fused_specdecode_generate(
792845 _emit (on_commit , generated )
793846 if any (t in eos for t in commit ):
794847 break
848+ # Greedy decoding can collapse into a runaway short-period loop (e.g.
849+ # the **/.2/* markdown-marker wall on code prompts); the drafter then
850+ # trivially predicts the repeats and the greedy verifier accepts them,
851+ # so acceptance stays HIGH while the output is garbage. Stop on it
852+ # instead of emitting an unbounded wall (keeps a short clean tail).
853+ if stop_on_runaway :
854+ drop = _trailing_runaway_drop (generated )
855+ if drop > 0 :
856+ del generated [len (generated ) - drop :]
857+ stopped_on_runaway = True
858+ break
795859 if (allow_greedy_fallback and len (accepts ) >= 2
796860 and (sum (accepts ) / len (accepts )) < 1.5 ):
797861 fallback_to_greedy = True
@@ -810,6 +874,12 @@ def fused_specdecode_generate(
810874 _emit (on_commit , generated )
811875 if tok in eos :
812876 break
877+ if stop_on_runaway :
878+ drop = _trailing_runaway_drop (generated )
879+ if drop > 0 :
880+ del generated [len (generated ) - drop :]
881+ stopped_on_runaway = True
882+ break
813883 timing ["fallback_greedy_s" ] += time .perf_counter () - t_fb
814884 finally :
815885 adapter ._capture_aux = False
@@ -820,5 +890,6 @@ def fused_specdecode_generate(
820890 "mean_accept_len" : (round (sum (accepts ) / len (accepts ), 3 )
821891 if accepts else 0.0 ),
822892 "decode_tokens" : len (generated ),
893+ "stopped_on_runaway" : stopped_on_runaway ,
823894 "time_breakdown_s" : {k : round (v , 3 ) for k , v in timing .items ()},
824895 }
0 commit comments