@@ -51,6 +51,12 @@ def append_token(self, token_id):
5151 self .appends .append (token_id )
5252 return self .next_token_logits
5353
54+ def last_aux_torch_slice (self , start = 0 , end = None ):
55+ # Mirror MLXRestoredIncrementalVerifier.last_aux_torch_slice: per-aux-layer
56+ # torch rows of the most recent forward_block, sliced [start:end].
57+ aux = self ._last_aux or [torch .zeros (1 , self .hidden )]
58+ return [a [start :end ] for a in aux ]
59+
5460
5561class _FakeDrafter :
5662 def __init__ (self , drafts ):
@@ -89,13 +95,16 @@ def test_fused_loop_full_acceptance():
8995 res = fsd .fused_specdecode_generate (
9096 adapter , drafter , gen_tokens = 5 , block_size = 4 , eos_ids = (),
9197 ** _loop_kwargs (drafter ))
92- # Block1: candidate=[100,101,102] all accepted (3) + correction 103.
93- # Block2: candidate=[104] accepted (1) + correction 105 -> truncated to 5.
98+ # Block1: candidate=[100,101,102] fully accepted (3). On FULL acceptance the
99+ # loop reuses block_logits[-1] (=103) as the next distribution and does NOT
100+ # append a correction token. next=103.
101+ # Block2: L=2 -> candidate=[103,200]; accept 103 (1), reject 200, correction
102+ # =104 appended -> commit [103,104]; total 5 tokens.
94103 assert res ["tokens" ] == [100 , 101 , 102 , 103 , 104 ]
95104 assert res ["blocks" ] == 2
96105 assert res ["mean_accept_len" ] == 2.0 # (3 + 1) / 2
97106 assert adapter .commits [0 ] == (3 , 3 ) # block1 verify-commit
98- assert adapter .appends == [103 , 105 ] # one correction per block
107+ assert adapter .appends == [104 ] # only block2's correction
99108 # capture flag toggled on during loop, off after.
100109 assert adapter ._capture_aux is False
101110 # context K/V extended once per block.
@@ -122,9 +131,27 @@ def test_fused_loop_stops_on_eos():
122131 res = fsd .fused_specdecode_generate (
123132 adapter , drafter , gen_tokens = 50 , block_size = 4 , eos_ids = (103 ,),
124133 ** _loop_kwargs (drafter ))
125- # correction 103 is EOS -> stop after first block.
134+ # Block1 fully accepts [100,101,102] (no correction appended on full accept),
135+ # leaving next=103. Block2's bonus is then 103 (EOS), committed and stopped.
126136 assert res ["tokens" ] == [100 , 101 , 102 , 103 ]
127- assert res ["blocks" ] == 1
137+ assert res ["blocks" ] == 2
138+
139+
140+ def test_fused_loop_greedy_fallback_on_low_acceptance ():
141+ adapter = _FakeAdapter (prompt_len = 5 , first_token = 100 )
142+ # Each block accepts only the bonus (drafts mismatch the verifier), so after
143+ # 2 blocks mean acceptance = 1.0 < 1.5 and the loop switches to plain greedy
144+ # to finish the budget (no aux capture, no drafter extension past the blocks).
145+ drafter = _FakeDrafter (drafts = [[999 , 999 , 999 ], [999 , 999 , 999 ]])
146+ res = fsd .fused_specdecode_generate (
147+ adapter , drafter , gen_tokens = 10 , block_size = 4 , eos_ids = (),
148+ ** _loop_kwargs (drafter ))
149+ # blocks 1-2 commit [100,101] then [102,103]; greedy fallback adds 104..109.
150+ assert res ["tokens" ] == [100 , 101 , 102 , 103 , 104 , 105 , 106 , 107 , 108 , 109 ]
151+ assert res ["blocks" ] == 2 # only the speculative blocks are counted
152+ assert res ["mean_accept_len" ] == 1.0 # (1 + 1) / 2
153+ assert adapter ._capture_aux is False # turned off for the greedy tail
154+ assert drafter .extend_calls == 2 # extended only during the spec blocks
128155
129156
130157# =========================================================================== #
@@ -307,8 +334,10 @@ def as_linear(self, h): return "L"
307334 adapter ._capture_aux = True
308335 logits = adapter .forward_block ([7 , 8 ])
309336 assert logits == "ROW" # _Model returns row at [0]
310- # aux = [hs[1]] = [layer-0 output]; bridged after stripping batch ([0]).
311- assert adapter ._last_aux == [("torch" , ("row" , 0 ))]
337+ # aux = [hs[1]] = [layer-0 output], captured LAZILY in MX (_last_aux_mx);
338+ # _last_aux stays None and the torch bridge happens on demand.
339+ assert adapter ._last_aux is None
340+ assert adapter .last_aux_torch_slice () == [("torch" , ("row" , 0 ))]
312341
313342 # commit_or_truncate trims by (forwarded - accepted) and advances _past_len
314343 adapter .commit_or_truncate (forwarded = 2 , accepted = 1 )
0 commit comments