Skip to content

Commit f13594d

Browse files
authored
Merge pull request #147 from FluffyAIcode/AgentMemory/refresh-fused-specdecode-test-fakes-2815
test(mlx-fused): refresh stale fused-loop test fakes and expectations
2 parents be0a35e + 8baa1bf commit f13594d

1 file changed

Lines changed: 36 additions & 7 deletions

File tree

tests/backends/mlx/test_fused_specdecode.py

Lines changed: 36 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,12 @@ def append_token(self, token_id):
5151
self.appends.append(token_id)
5252
return self.next_token_logits
5353

54+
def last_aux_torch_slice(self, start=0, end=None):
55+
# Mirror MLXRestoredIncrementalVerifier.last_aux_torch_slice: per-aux-layer
56+
# torch rows of the most recent forward_block, sliced [start:end].
57+
aux = self._last_aux or [torch.zeros(1, self.hidden)]
58+
return [a[start:end] for a in aux]
59+
5460

5561
class _FakeDrafter:
5662
def __init__(self, drafts):
@@ -89,13 +95,16 @@ def test_fused_loop_full_acceptance():
8995
res = fsd.fused_specdecode_generate(
9096
adapter, drafter, gen_tokens=5, block_size=4, eos_ids=(),
9197
**_loop_kwargs(drafter))
92-
# Block1: candidate=[100,101,102] all accepted (3) + correction 103.
93-
# Block2: candidate=[104] accepted (1) + correction 105 -> truncated to 5.
98+
# Block1: candidate=[100,101,102] fully accepted (3). On FULL acceptance the
99+
# loop reuses block_logits[-1] (=103) as the next distribution and does NOT
100+
# append a correction token. next=103.
101+
# Block2: L=2 -> candidate=[103,200]; accept 103 (1), reject 200, correction
102+
# =104 appended -> commit [103,104]; total 5 tokens.
94103
assert res["tokens"] == [100, 101, 102, 103, 104]
95104
assert res["blocks"] == 2
96105
assert res["mean_accept_len"] == 2.0 # (3 + 1) / 2
97106
assert adapter.commits[0] == (3, 3) # block1 verify-commit
98-
assert adapter.appends == [103, 105] # one correction per block
107+
assert adapter.appends == [104] # only block2's correction
99108
# capture flag toggled on during loop, off after.
100109
assert adapter._capture_aux is False
101110
# context K/V extended once per block.
@@ -122,9 +131,27 @@ def test_fused_loop_stops_on_eos():
122131
res = fsd.fused_specdecode_generate(
123132
adapter, drafter, gen_tokens=50, block_size=4, eos_ids=(103,),
124133
**_loop_kwargs(drafter))
125-
# correction 103 is EOS -> stop after first block.
134+
# Block1 fully accepts [100,101,102] (no correction appended on full accept),
135+
# leaving next=103. Block2's bonus is then 103 (EOS): it is committed and the loop stops.
126136
assert res["tokens"] == [100, 101, 102, 103]
127-
assert res["blocks"] == 1
137+
assert res["blocks"] == 2
138+
139+
140+
def test_fused_loop_greedy_fallback_on_low_acceptance():
141+
adapter = _FakeAdapter(prompt_len=5, first_token=100)
142+
# Each block accepts only the bonus (drafts mismatch the verifier), so after
143+
# 2 blocks mean acceptance = 1.0 < 1.5 and the loop switches to plain greedy
144+
# to finish the budget (no aux capture, no drafter extension past the blocks).
145+
drafter = _FakeDrafter(drafts=[[999, 999, 999], [999, 999, 999]])
146+
res = fsd.fused_specdecode_generate(
147+
adapter, drafter, gen_tokens=10, block_size=4, eos_ids=(),
148+
**_loop_kwargs(drafter))
149+
# blocks 1-2 commit [100,101] then [102,103]; greedy fallback adds 104..109.
150+
assert res["tokens"] == [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
151+
assert res["blocks"] == 2 # only the speculative blocks are counted
152+
assert res["mean_accept_len"] == 1.0 # (1 + 1) / 2
153+
assert adapter._capture_aux is False # turned off for the greedy tail
154+
assert drafter.extend_calls == 2 # extended only during the spec blocks
128155

129156

130157
# =========================================================================== #
@@ -307,8 +334,10 @@ def as_linear(self, h): return "L"
307334
adapter._capture_aux = True
308335
logits = adapter.forward_block([7, 8])
309336
assert logits == "ROW" # _Model returns row at [0]
310-
# aux = [hs[1]] = [layer-0 output]; bridged after stripping batch ([0]).
311-
assert adapter._last_aux == [("torch", ("row", 0))]
337+
# aux = [hs[1]] = [layer-0 output], captured LAZILY in MX (_last_aux_mx);
338+
# _last_aux stays None and the torch bridge happens on demand.
339+
assert adapter._last_aux is None
340+
assert adapter.last_aux_torch_slice() == [("torch", ("row", 0))]
312341

313342
# commit_or_truncate trims by (forwarded - accepted) and advances _past_len
314343
adapter.commit_or_truncate(forwarded=2, accepted=1)

0 commit comments

Comments
 (0)