Skip to content

Commit 5920155

Browse files
authored
Merge pull request #31 from FluffyAIcode/AgentMemory/v030-pr7-1-cache-parallel-token-seq-8e7f
PR 7-1 (ADR 0007): cached_token_sequence + INV-1 on both verifiers
2 parents 00172ae + 56e8c5c commit 5920155

4 files changed

Lines changed: 401 additions & 2 deletions

File tree

inference_engine/backends/mlx/verifier.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ def __init__(self, config: Optional[VerifierConfig] = None) -> None:
9191
self.cache_logical_size: int = 0
9292
self.next_global_position: int = 0
9393
self.next_token_logits: Optional[torch.Tensor] = None
94+
# Parallel record of the token id at every K/V cache slot, in
95+
# the same physical order as ``self.cache[*].keys``. See the CPU
96+
# verifier for the full motivation; in short, this is required
97+
# by ADR 0007 §2.2 (path-selection needs token-id-level
98+
# comparison against the cache) and §2.9 INV-1 (parallel-
99+
# sequence consistency). Maintained synchronously with the
100+
# K/V tensors by every cache mutation method below.
101+
self.cached_token_sequence: List[int] = []
94102

95103
self.quantization: QuantizationInfo = detect_quantization(self.model)
96104
self.stats = VerifierStats(weight_bytes=self.quantization.total_weight_bytes)
@@ -106,6 +114,8 @@ def reset(self) -> None:
106114
self.cache_logical_size = 0
107115
self.next_global_position = 0
108116
self.next_token_logits = None
117+
self.cached_token_sequence = []
118+
self._assert_cache_invariant_1()
109119

110120
def prefill(self, prompt_ids: List[int]) -> None:
111121
if not prompt_ids:
@@ -125,9 +135,16 @@ def prefill(self, prompt_ids: List[int]) -> None:
125135
self.next_token_logits = mx_to_torch(logits_mx[0, -1])
126136
self.next_global_position = L
127137
self.cache_logical_size = self._cache_buffer_size()
138+
# Compute the post-trim parallel token sequence directly. The
139+
# MLX SinkWindowKVCache trims inside update_and_fetch on every
140+
# forward, so by the time we get here the per-layer K/V tensors
141+
# already hold the sink+window slice of ``prompt_ids``. We
142+
# mirror that slice on cached_token_sequence so INV-1 holds.
143+
self.cached_token_sequence = self._sink_window_slice(prompt_ids)
128144
self._record_peak_kv()
129145
self.stats.forward_calls += 1
130146
self.stats.tokens_consumed += L
147+
self._assert_cache_invariant_1()
131148

132149
def forward_block(self, tokens: List[int]) -> torch.Tensor:
133150
if self.cache is None:
@@ -150,9 +167,15 @@ def forward_block(self, tokens: List[int]) -> torch.Tensor:
150167
# SinkWindowKVCache.update_and_fetch trim. Read directly from
151168
# the cache rather than tracking it ourselves.
152169
self.cache_logical_size = self._cache_buffer_size()
170+
# Mirror the same trim on the parallel sequence: take the
171+
# current sequence concatenated with the new tokens and apply
172+
# the sink+window slice.
173+
extended = self.cached_token_sequence + list(tokens)
174+
self.cached_token_sequence = self._sink_window_slice(extended)
153175
block_logits = mx_to_torch(logits_mx[0]) # [L, V]
154176
self.stats.forward_calls += 1
155177
self.stats.tokens_consumed += L
178+
self._assert_cache_invariant_1()
156179
return block_logits
157180

158181
def commit_or_truncate(self, forwarded: int, accepted: int) -> None:
@@ -169,9 +192,12 @@ def commit_or_truncate(self, forwarded: int, accepted: int) -> None:
169192
f"per-layer trim mismatch (asked drop={drop}, got {trims}); "
170193
"SinkWindowKVCache state diverged across layers"
171194
)
195+
# Mirror the tail truncation on the parallel sequence.
196+
self.cached_token_sequence = self.cached_token_sequence[:-drop]
172197
self.cache_logical_size = self._cache_buffer_size()
173198
self.next_global_position += accepted
174199
self._record_peak_kv()
200+
self._assert_cache_invariant_1()
175201

176202
def append_token(self, token_id: int) -> torch.Tensor:
177203
logits = self.forward_block([token_id])
@@ -187,6 +213,46 @@ def _cache_buffer_size(self) -> int:
187213
return 0
188214
return cache_ops.cache_seq_length(self.cache)
189215

216+
def _sink_window_slice(self, sequence: List[int]) -> List[int]:
217+
"""Return ``sequence`` after the sink+window trim that the K/V
218+
cache applies.
219+
220+
Mirrors ``SinkWindowKVCache.update_and_fetch``'s trim logic at
221+
the token-id level: if the input length exceeds the budget,
222+
keep the first ``sink_size`` entries and the last
223+
``window_size`` entries; otherwise return unchanged.
224+
"""
225+
budget = self.config.sink_size + self.config.window_size
226+
if len(sequence) <= budget:
227+
return list(sequence)
228+
return (
229+
list(sequence[: self.config.sink_size])
230+
+ list(sequence[-self.config.window_size :])
231+
)
232+
233+
def _assert_cache_invariant_1(self) -> None:
234+
"""ADR 0007 §2.9 INV-1: parallel-sequence consistency.
235+
236+
After every cache mutation, ``len(self.cached_token_sequence)``
237+
must equal the K/V tensor sequence dimension. Violation
238+
indicates a bug in the verifier's cache-mutation path; per ADR
239+
0007 §2.9 the implementation must raise — never silently
240+
recover, never fall back, never re-sync.
241+
"""
242+
actual = len(self.cached_token_sequence)
243+
expected = self._cache_buffer_size()
244+
if actual != expected:
245+
raise AssertionError(
246+
f"INV-1 violated (parallel-sequence consistency): "
247+
f"cached_token_sequence has {actual} entries but K/V "
248+
f"cache seq dim is {expected}. This is a bug in the "
249+
f"verifier's cache-mutation path; per ADR 0007 §2.9 it "
250+
f"must surface as a critical error rather than be "
251+
f"silently recovered. cache_logical_size="
252+
f"{self.cache_logical_size}, "
253+
f"next_global_position={self.next_global_position}."
254+
)
255+
190256
def live_kv_bytes(self) -> int:
191257
"""Return the current size of the verifier's live KV cache in bytes.
192258

kv_cache_proposer/verifier.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,23 @@ def __init__(self, config: Optional[VerifierConfig] = None) -> None:
7676
# Logits at `next_global_position` predicting the next token. Updated
7777
# after every forward pass.
7878
self.next_token_logits: Optional[torch.Tensor] = None
79+
# Parallel record of the token id at every K/V cache slot, in the
80+
# same physical order as ``self.cache.layers[*].keys``. Maintained
81+
# synchronously with the K/V tensors by every cache mutation
82+
# method below. Required by ADR 0007 §2.2 + §2.9 INV-1: the
83+
# path-selection algorithm (PR 7-2) needs token-id-level prefix
84+
# matching against the cache, and the K/V tensors don't expose
85+
# token ids.
86+
#
87+
# Storage: at most ``sink_size + window_size`` int entries, so
88+
# bounded at the same constant the K/V cache is bounded at
89+
# (e.g. 68 entries × 8 bytes per Python int = 544 bytes,
90+
# negligible vs the 7.4 MiB K/V).
91+
#
92+
# Invariant INV-1 (ADR 0007 §2.9): after every cache mutation,
93+
# ``len(self.cached_token_sequence)`` equals the K/V tensor
94+
# sequence dimension. Enforced by ``_assert_cache_invariant_1``.
95+
self.cached_token_sequence: List[int] = []
7996

8097
self.stats = VerifierStats(
8198
weight_bytes=sum(p.numel() * p.element_size() for p in self.model.parameters())
@@ -87,6 +104,8 @@ def reset(self) -> None:
87104
self.cache_logical_size = 0
88105
self.next_global_position = 0
89106
self.next_token_logits = None
107+
self.cached_token_sequence = []
108+
self._assert_cache_invariant_1()
90109

91110
@torch.no_grad()
92111
def prefill(self, prompt_ids: List[int]) -> None:
@@ -112,11 +131,19 @@ def prefill(self, prompt_ids: List[int]) -> None:
112131
self.next_global_position = L
113132
self.next_token_logits = outputs.logits[0, -1].clone()
114133

134+
# Update parallel token sequence in lockstep with the K/V cache.
135+
# After this prefill the cache holds K/V for all L tokens; the
136+
# subsequent ``_trim_cache_in_place`` will drop middle entries
137+
# to enforce sink+window. We mirror that exact transformation
138+
# on ``cached_token_sequence``.
139+
self.cached_token_sequence = list(prompt_ids)
140+
115141
self._record_peak_activation(outputs.logits)
116142
self._trim_cache_in_place()
117143
self._record_peak_kv()
118144
self.stats.forward_calls += 1
119145
self.stats.tokens_consumed += L
146+
self._assert_cache_invariant_1()
120147

121148
@torch.no_grad()
122149
def forward_block(self, tokens: List[int]) -> torch.Tensor:
@@ -153,9 +180,13 @@ def forward_block(self, tokens: List[int]) -> torch.Tensor:
153180
self.cache = outputs.past_key_values
154181
# Cache provisionally has cache_start + L slots until commit/truncate.
155182
self.cache_logical_size = cache_start + L
183+
# Mirror the provisional extension on the parallel sequence;
184+
# commit_or_truncate will drop the unaccepted tail in lockstep.
185+
self.cached_token_sequence = self.cached_token_sequence + list(tokens)
156186
self._record_peak_activation(outputs.logits)
157187
self.stats.forward_calls += 1
158188
self.stats.tokens_consumed += L
189+
self._assert_cache_invariant_1()
159190
# Don't trim yet — caller decides how many tokens were accepted.
160191
return outputs.logits[0].clone() # [L, V]
161192

@@ -177,10 +208,13 @@ def commit_or_truncate(
177208
drop = forwarded - accepted
178209
if drop > 0:
179210
self._truncate_tail_in_place(drop)
211+
# Mirror the tail truncation on the parallel sequence.
212+
self.cached_token_sequence = self.cached_token_sequence[:-drop]
180213
self.cache_logical_size -= drop
181214
self.next_global_position += accepted
182215
self._trim_cache_in_place()
183216
self._record_peak_kv()
217+
self._assert_cache_invariant_1()
184218

185219
@torch.no_grad()
186220
def append_token(self, token_id: int) -> torch.Tensor:
@@ -226,6 +260,12 @@ def _trim_cache_in_place(self) -> None:
226260
# peak_kv_bytes would over-report.
227261
layer.keys = torch.cat([sink_k, tail_k], dim=2).contiguous()
228262
layer.values = torch.cat([sink_v, tail_v], dim=2).contiguous()
263+
# Mirror the same sink+window slice on the parallel token sequence
264+
# so cached_token_sequence stays in lockstep with the K/V tensors.
265+
self.cached_token_sequence = (
266+
self.cached_token_sequence[:sink]
267+
+ self.cached_token_sequence[-keep_window:]
268+
)
229269
self.cache_logical_size = budget
230270

231271
def _truncate_tail_in_place(self, drop: int) -> None:
@@ -246,6 +286,44 @@ def _truncate_tail_in_place(self, drop: int) -> None:
246286
layer.keys = keys[:, :, :keep, :].contiguous()
247287
layer.values = values[:, :, :keep, :].contiguous()
248288

289+
def _cache_seq_length(self) -> int:
290+
"""Return the seq dim of the cache K/V tensors, or 0 if empty.
291+
292+
Reads from the first non-empty layer; all layers share the same
293+
seq dim by construction (every K/V mutation in this class
294+
applies the same shape transformation across all layers).
295+
"""
296+
if self.cache is None:
297+
return 0
298+
for layer in self.cache.layers:
299+
keys = getattr(layer, "keys", None)
300+
if keys is not None:
301+
return int(keys.shape[2])
302+
return 0
303+
304+
def _assert_cache_invariant_1(self) -> None:
305+
"""ADR 0007 §2.9 INV-1: parallel-sequence consistency.
306+
307+
After every cache mutation, ``len(self.cached_token_sequence)``
308+
must equal the K/V tensor sequence dimension. Violation
309+
indicates a bug in the cache-mutation path; per ADR 0007 §2.9
310+
the implementation must raise — never silently recover, never
311+
fall back, never re-sync.
312+
"""
313+
actual = len(self.cached_token_sequence)
314+
expected = self._cache_seq_length()
315+
if actual != expected:
316+
raise AssertionError(
317+
f"INV-1 violated (parallel-sequence consistency): "
318+
f"cached_token_sequence has {actual} entries but K/V "
319+
f"cache seq dim is {expected}. This is a bug in the "
320+
f"verifier's cache-mutation path; per ADR 0007 §2.9 it "
321+
f"must surface as a critical error rather than be "
322+
f"silently recovered. cache_logical_size="
323+
f"{self.cache_logical_size}, "
324+
f"next_global_position={self.next_global_position}."
325+
)
326+
249327
def live_kv_bytes(self) -> int:
250328
"""Return the current size of the verifier's live KV cache in bytes.
251329

tests/backends/mlx/test_verifier.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,122 @@ def test_live_kv_bytes_nonzero_after_prefill() -> None:
273273
assert v.stats.peak_kv_bytes == n
274274

275275

276+
# ---------------------------------------------------------------------------
277+
# ADR 0007 §2.2 + §2.9 — cached_token_sequence + INV-1
278+
# ---------------------------------------------------------------------------
279+
280+
281+
def test_mlx_cached_token_sequence_empty_after_construction() -> None:
282+
v = _build_mlx_verifier()
283+
assert v.cached_token_sequence == []
284+
v._assert_cache_invariant_1()
285+
286+
287+
def test_mlx_cached_token_sequence_populated_after_short_prefill() -> None:
288+
v = _build_mlx_verifier(sink=2, window=8)
289+
prompt = list(range(5)) # 5 < sink+window = 10
290+
v.prefill(prompt)
291+
assert v.cached_token_sequence == prompt
292+
v._assert_cache_invariant_1()
293+
294+
295+
def test_mlx_cached_token_sequence_trimmed_after_long_prefill() -> None:
296+
v = _build_mlx_verifier(sink=2, window=4)
297+
prompt = list(range(20)) # 20 > sink+window = 6
298+
v.prefill(prompt)
299+
expected = prompt[:2] + prompt[-4:]
300+
assert v.cached_token_sequence == expected
301+
v._assert_cache_invariant_1()
302+
303+
304+
def test_mlx_cached_token_sequence_extends_on_forward_block() -> None:
305+
"""``forward_block`` extends the cache; the parallel sequence
306+
extends in lockstep, then the same sink+window slice that the
307+
K/V tensors apply is applied here too."""
308+
v = _build_mlx_verifier(sink=2, window=8)
309+
v.prefill([0, 1, 2, 3])
310+
v.forward_block([4, 5])
311+
# 6 entries, all under budget=10
312+
assert v.cached_token_sequence == [0, 1, 2, 3, 4, 5]
313+
v._assert_cache_invariant_1()
314+
315+
316+
def test_mlx_cached_token_sequence_drops_rejected_tail_on_partial_accept() -> None:
317+
v = _build_mlx_verifier(sink=2, window=8)
318+
v.prefill([0, 1, 2, 3])
319+
v.forward_block([4, 5, 6])
320+
v.commit_or_truncate(forwarded=3, accepted=1)
321+
assert v.cached_token_sequence == [0, 1, 2, 3, 4]
322+
v._assert_cache_invariant_1()
323+
324+
325+
def test_mlx_cached_token_sequence_after_append_token() -> None:
326+
v = _build_mlx_verifier(sink=2, window=8)
327+
v.prefill([0, 1, 2, 3])
328+
v.append_token(99)
329+
assert v.cached_token_sequence == [0, 1, 2, 3, 99]
330+
v._assert_cache_invariant_1()
331+
332+
333+
def test_mlx_cached_token_sequence_cleared_on_reset() -> None:
334+
v = _build_mlx_verifier(sink=2, window=8)
335+
v.prefill([0, 1, 2, 3])
336+
assert v.cached_token_sequence != []
337+
v.reset()
338+
assert v.cached_token_sequence == []
339+
v._assert_cache_invariant_1()
340+
341+
342+
def test_mlx_inv_1_violation_raises_assertion_error() -> None:
343+
v = _build_mlx_verifier(sink=2, window=8)
344+
v.prefill([0, 1, 2, 3])
345+
v.cached_token_sequence = v.cached_token_sequence + [999]
346+
with pytest.raises(AssertionError, match="INV-1 violated"):
347+
v._assert_cache_invariant_1()
348+
349+
350+
def test_mlx_inv_1_assertion_message_carries_diagnostic_state() -> None:
351+
"""The error message must expose actual vs expected lengths plus
352+
the verifier's logical-position counters so a bug report can be
353+
triaged from the message alone."""
354+
v = _build_mlx_verifier()
355+
v.prefill([0, 1, 2, 3])
356+
v.cached_token_sequence = v.cached_token_sequence + [42, 43]
357+
with pytest.raises(AssertionError) as exc:
358+
v._assert_cache_invariant_1()
359+
msg = str(exc.value)
360+
assert "INV-1" in msg
361+
assert "cached_token_sequence" in msg
362+
assert "cache_logical_size=" in msg
363+
assert "next_global_position=" in msg
364+
365+
366+
def test_mlx_inv_1_holds_when_cache_is_none() -> None:
367+
"""The pre-prefill state (cache None, sequence []) is the trivial
368+
INV-1 satisfaction — must not raise."""
369+
v = _build_mlx_verifier()
370+
assert v.cache is None
371+
assert v.cached_token_sequence == []
372+
v._assert_cache_invariant_1()
373+
374+
375+
def test_mlx_sink_window_slice_below_budget_returns_input_unchanged() -> None:
376+
"""The internal helper short-circuits when sequence fits in
377+
sink+window."""
378+
v = _build_mlx_verifier(sink=2, window=4)
379+
seq = [10, 20, 30]
380+
out = v._sink_window_slice(seq)
381+
assert out == seq
382+
assert out is not seq # returns a copy
383+
384+
385+
def test_mlx_sink_window_slice_above_budget_keeps_sink_plus_tail() -> None:
386+
v = _build_mlx_verifier(sink=2, window=4)
387+
seq = list(range(20))
388+
out = v._sink_window_slice(seq)
389+
assert out == seq[:2] + seq[-4:]
390+
391+
276392
def test_record_peak_activation_grows_only() -> None:
277393
v = _build_mlx_verifier()
278394
a = mx.zeros((1, 4, 32), dtype=mx.bfloat16)

0 commit comments

Comments
 (0)