Skip to content

Commit b00cbe5

Browse files
authored
Merge pull request #32 from FluffyAIcode/AgentMemory/v030-pr7-2-path-select-incremental-prefill-8e7f
PR 7-2 (ADR 0007): path_select + prefill_incremental + INV-2
2 parents 5920155 + f3b3c64 commit b00cbe5

6 files changed

Lines changed: 783 additions & 0 deletions

File tree

inference_engine/backends/mlx/verifier.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,70 @@ def commit_or_truncate(self, forwarded: int, accepted: int) -> None:
199199
self._record_peak_kv()
200200
self._assert_cache_invariant_1()
201201

202+
def path_select(self, prompt: List[int]) -> "PathPlan":
    """Select between continuation and new-session paths for ``prompt``.

    Same contract as the CPU verifier; see
    :meth:`kv_cache_proposer.verifier.SinkWindowVerifier.path_select`
    for full semantics. Implements ADR 0007 §2.4 and asserts
    INV-2 (§2.9).

    Args:
        prompt: Token ids of the full incoming prompt.

    Returns:
        ``NewSession(prompt)`` when there is no cache yet
        (``self.cache is None`` or ``self.cache_logical_size == 0``),
        when the prompt is shorter than ``self.next_global_position``,
        or when ``_prompt_matches_cached_positions`` reports a
        mismatch. Otherwise ``ContinuationPlan(skip_n, new_tokens)``
        with ``skip_n == self.next_global_position`` and
        ``new_tokens == prompt[skip_n:]`` (possibly empty).

    Raises:
        ValueError: If ``prompt`` is empty.
        AssertionError: If INV-2 is violated, i.e. the planned
            ``skip_n`` differs from ``self.next_global_position``.
    """
    # Validate before the deferred import so bad input fails fast and
    # never pays for (or is masked by) the import.
    if not prompt:
        raise ValueError("prompt must be non-empty")

    from kv_cache_proposer.path_plan import (  # avoid circular
        ContinuationPlan,
        NewSession,
    )

    prompt_list = list(prompt)

    # Cold start: nothing cached to continue from.
    if self.cache is None or self.cache_logical_size == 0:
        return NewSession(prompt=prompt_list)

    cache_end = self.next_global_position

    # Shorter history: the prompt ends before the cache's logical end.
    if len(prompt_list) < cache_end:
        return NewSession(prompt=prompt_list)

    # Diverging history: a cached token disagrees with the prompt.
    if not self._prompt_matches_cached_positions(prompt_list):
        return NewSession(prompt=prompt_list)

    skip_n = cache_end
    new_tokens = prompt_list[skip_n:]

    # INV-2 guard. ``skip_n`` was read from the same attribute just
    # above, so this only fires if ``next_global_position`` changes
    # between the two reads; it is kept as an explicit tripwire.
    if skip_n != self.next_global_position:
        raise AssertionError(
            f"INV-2 violated (position monotonicity): planned "
            f"skip_n={skip_n} but next_global_position="
            f"{self.next_global_position}. Continuation must "
            f"extend exactly from the cache's logical end. This "
            f"is a bug in path_select; ADR 0007 §2.9 forbids "
            f"silent recovery. cache_logical_size="
            f"{self.cache_logical_size}, "
            f"cached_token_sequence_len="
            f"{len(self.cached_token_sequence)}."
        )

    return ContinuationPlan(skip_n=skip_n, new_tokens=new_tokens)
245+
246+
def prefill_incremental(self, new_tokens: List[int]) -> None:
    """Run incremental prefill on ``new_tokens``, reusing cached state.

    Same contract as the CPU verifier; see
    :meth:`kv_cache_proposer.verifier.SinkWindowVerifier.prefill_incremental`.

    The tokens are forwarded with ``forward_block`` and then all of
    them are committed through ``commit_or_truncate`` (they are prompt
    tokens, so ``accepted == forwarded``). Afterwards
    ``next_token_logits`` holds a clone of the last row of the block
    logits.

    Args:
        new_tokens: Suffix of the prompt not yet in the cache, normally
            ``ContinuationPlan.new_tokens``. An empty value is a no-op.

    Raises:
        RuntimeError: If no cache exists yet (``self.cache is None``).
    """
    if self.cache is None:
        raise RuntimeError(
            "prefill_incremental called before any prefill; cache "
            "is None. Call path_select first and route NewSession "
            "to prefill() instead."
        )
    if not new_tokens:
        return  # nothing to forward; existing state is left untouched

    # Snapshot once so the block that is forwarded and the counts that
    # are committed are guaranteed to describe the same tokens.
    tokens = list(new_tokens)
    block_logits = self.forward_block(tokens)
    self.commit_or_truncate(forwarded=len(tokens), accepted=len(tokens))
    self.next_token_logits = block_logits[-1].clone()
265+
202266
def append_token(self, token_id: int) -> torch.Tensor:
203267
logits = self.forward_block([token_id])
204268
self.commit_or_truncate(forwarded=1, accepted=1)
@@ -213,6 +277,42 @@ def _cache_buffer_size(self) -> int:
213277
return 0
214278
return cache_ops.cache_seq_length(self.cache)
215279

280+
def _cached_global_positions(self) -> List[int]:
    """Global token positions currently held in the cache.

    See :meth:`kv_cache_proposer.verifier.SinkWindowVerifier._cached_global_positions`
    for semantics.

    With ``n = self.next_global_position``:

    - ``n == 0``: empty list.
    - ``n <= sink_size + window_size``: ``[0, ..., n - 1]``.
    - otherwise: ``[0, ..., sink_size - 1]`` followed by
      ``[n - window_size, ..., n - 1]``.

    Returns:
        The positions in ascending order.
    """
    n = self.next_global_position
    if n == 0:
        return []

    sink = self.config.sink_size
    window = self.config.window_size
    if n <= sink + window:
        # Everything seen so far still fits in the budget.
        return list(range(n))

    # Past the budget: the sink prefix plus the trailing window.
    return [*range(sink), *range(n - window, n)]
296+
297+
def _prompt_matches_cached_positions(self, prompt: List[int]) -> bool:
    """Token-id-level check for ADR 0007 §2.4.a.2.

    Compares ``prompt[p]`` with the cached token id for every global
    position ``p`` returned by ``_cached_global_positions``.

    NOTE(review): only positions reported by
    ``_cached_global_positions`` are compared; prompt tokens at any
    other position are not checked here. Confirm this is the intended
    reading of ADR 0007 §2.4.a.2.

    Args:
        prompt: Token ids of the full incoming prompt.

    Returns:
        ``True`` if every cached position exists in ``prompt`` and
        holds the same token id; ``False`` otherwise.

    Raises:
        AssertionError: If the position list and
            ``cached_token_sequence`` differ in length.
    """
    positions = self._cached_global_positions()
    cached = self.cached_token_sequence
    if len(positions) != len(cached):
        raise AssertionError(
            f"_prompt_matches_cached_positions: position list of "
            f"length {len(positions)} disagrees with parallel "
            f"sequence of length {len(cached)}; "
            f"INV-1 should have caught this earlier"
        )

    prompt_len = len(prompt)
    # The bounds test short-circuits before indexing, so a prompt that
    # is too short yields False instead of raising IndexError.
    return all(
        pos < prompt_len and int(prompt[pos]) == int(token)
        for pos, token in zip(positions, cached)
    )
315+
216316
def _sink_window_slice(self, sequence: List[int]) -> List[int]:
217317
"""Return ``sequence`` after the sink+window trim that the K/V
218318
cache applies.

kv_cache_proposer/path_plan.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
"""Path-selection result types for cross-request KV cache reuse.
2+
3+
Implements the result types ``ContinuationPlan`` and ``NewSession``
4+
described in ADR 0007 §2.4. Both CPU and MLX verifiers expose a
5+
``path_select(prompt)`` method returning one of these.
6+
7+
ADR 0007 §2.4 contract recap:
8+
9+
Continuation precondition (both must hold):
10+
1. ``len(prompt) >= cache_logical_end`` (the new prompt extends
11+
at or past the position the cache already covers).
12+
2. The new prompt's tokens at every cached logical position
13+
equal ``cached_token_sequence`` at the corresponding slot.
14+
15+
When the precondition holds → ``ContinuationPlan(skip_n,
16+
new_tokens)``: the verifier should run ``prefill_incremental``
17+
on ``new_tokens`` (the suffix of ``prompt`` after the cached
18+
prefix), reusing the existing K/V cache state.
19+
20+
When the precondition fails (cold start, shorter history,
21+
diverging history) → ``NewSession(prompt)``: the verifier should
22+
run a full ``prefill(prompt)`` (which calls ``reset()`` first and
23+
rebuilds the cache from scratch).
24+
25+
The two paths are first-class deterministic actions per ADR 0007
26+
§2.4.c. Selecting NewSession is **not** a fallback from
27+
ContinuationPlan — both produce bit-identical output for their
28+
input class (per §2.7); the only difference is computational cost.
29+
"""
30+
31+
from __future__ import annotations
32+
33+
from dataclasses import dataclass, field
34+
from typing import List, Union
35+
36+
37+
@dataclass(frozen=True)
class ContinuationPlan:
    """Continuation path: reuse the cached prefix.

    Attributes
    ----------
    skip_n
        Number of tokens at the start of the new prompt that are
        already covered by the cache. The verifier should NOT
        re-prefill these. By construction
        ``skip_n == verifier.next_global_position`` at the moment
        ``path_select`` ran (see ADR 0007 §2.9 INV-2).
    new_tokens
        The suffix of the new prompt that is NOT yet in the cache.
        Length = ``len(prompt) - skip_n``. Usually non-empty; it is
        empty only in the exact-match case, where the new prompt adds
        nothing beyond the cached state and the plan is
        ``ContinuationPlan(skip_n=len(prompt), new_tokens=[])``.

    Raises
    ------
    ValueError
        If ``skip_n`` is negative.
    """

    skip_n: int
    new_tokens: List[int] = field(default_factory=list)

    def __post_init__(self) -> None:
        if self.skip_n < 0:
            raise ValueError(f"skip_n must be >= 0, got {self.skip_n}")
        # Note: new_tokens may be empty (the rare exact-match case).
        # Store a private copy so later mutation of the caller's list
        # cannot change a frozen plan; object.__setattr__ is needed
        # because the dataclass is frozen.
        object.__setattr__(self, "new_tokens", list(self.new_tokens))
66+
67+
68+
@dataclass(frozen=True)
class NewSession:
    """New-session path: reset the cache and run full prefill.

    Triggered when any §2.4.b sub-case applies: cold start, shorter
    history, or diverging history.

    Attributes
    ----------
    prompt
        The full prompt to prefill. Always non-empty (the verifier
        rejects empty prompts upstream, and construction re-checks it).

    Raises
    ------
    ValueError
        If ``prompt`` is empty.
    """

    prompt: List[int]

    def __post_init__(self) -> None:
        if not self.prompt:
            raise ValueError("NewSession.prompt must be non-empty")
        # Store a private copy so later mutation of the caller's list
        # cannot change a frozen plan; object.__setattr__ is needed
        # because the dataclass is frozen.
        object.__setattr__(self, "prompt", list(self.prompt))
87+
88+
89+
# Result type of ``path_select``: exactly one of the two plan classes.
PathPlan = Union[ContinuationPlan, NewSession]


__all__ = ["ContinuationPlan", "NewSession", "PathPlan"]

kv_cache_proposer/verifier.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,110 @@ def commit_or_truncate(
216216
self._record_peak_kv()
217217
self._assert_cache_invariant_1()
218218

219+
def path_select(self, prompt: List[int]) -> "PathPlan":
    """Select between continuation and new-session paths for ``prompt``.

    Implements the deterministic two-path selection from ADR 0007
    §2.4. Returns ``ContinuationPlan`` when the new prompt extends
    the cached state monotonically (§2.4.a), or ``NewSession``
    otherwise (§2.4.b).

    This is **not** a fallback. Both paths are first-class correct
    actions for their respective input classes.

    Asserts ADR 0007 §2.9 INV-2 (position monotonicity within a
    session): a ``ContinuationPlan`` always has
    ``skip_n == self.next_global_position``. Violation indicates a
    bug in the path-selection algorithm and raises rather than
    falling back.

    Args:
        prompt: Token ids of the full incoming prompt.

    Returns:
        ``NewSession(prompt)`` for a cold start, a prompt shorter than
        ``self.next_global_position``, or a token mismatch at a cached
        position; otherwise ``ContinuationPlan(skip_n, new_tokens)``
        with ``new_tokens == prompt[skip_n:]`` (possibly empty).

    Raises:
        ValueError: If ``prompt`` is empty.
        AssertionError: If INV-2 is violated.
    """
    # Validate before the deferred import so bad input fails fast and
    # never pays for (or is masked by) the import.
    if not prompt:
        raise ValueError("prompt must be non-empty")

    from .path_plan import ContinuationPlan, NewSession  # avoid circular

    prompt_list = list(prompt)

    # Cold start (§2.4.b case 1)
    if self.cache is None or self.cache_logical_size == 0:
        return NewSession(prompt=prompt_list)

    cache_end = self.next_global_position

    # Shorter history (§2.4.b case 2): the new prompt cannot extend
    # the cached state because it ends before the cache's logical
    # end. This is a different conversation from the client's side.
    if len(prompt_list) < cache_end:
        return NewSession(prompt=prompt_list)

    # Diverging history (§2.4.b case 3): the cached tokens disagree
    # with the new prompt at any cached logical position.
    if not self._prompt_matches_cached_positions(prompt_list):
        return NewSession(prompt=prompt_list)

    # Continuation precondition satisfied.
    skip_n = cache_end
    new_tokens = prompt_list[skip_n:]

    # INV-2 (ADR 0007 §2.9): structural invariant — skip_n must
    # equal next_global_position. ``skip_n`` was read from that same
    # attribute just above, so this only fires if the value changes
    # between the two reads; surface it as a critical error per §2.9.
    if skip_n != self.next_global_position:
        raise AssertionError(
            f"INV-2 violated (position monotonicity): planned "
            f"skip_n={skip_n} but next_global_position="
            f"{self.next_global_position}. Continuation must "
            f"extend exactly from the cache's logical end. This "
            f"is a bug in path_select; ADR 0007 §2.9 forbids "
            f"silent recovery. cache_logical_size="
            f"{self.cache_logical_size}, "
            f"cached_token_sequence_len="
            f"{len(self.cached_token_sequence)}."
        )

    return ContinuationPlan(skip_n=skip_n, new_tokens=new_tokens)
280+
281+
@torch.no_grad()
def prefill_incremental(self, new_tokens: List[int]) -> None:
    """Run incremental prefill on ``new_tokens``, reusing cached state.

    Counterpart to :meth:`prefill`: where ``prefill`` resets and
    rebuilds the cache from scratch, ``prefill_incremental`` keeps
    the existing cache and only forwards the new tokens. Used by
    the continuation path of ADR 0007 §2.4.a.

    ``new_tokens`` is the suffix returned by ``path_select`` in
    ``ContinuationPlan.new_tokens``; the caller must have verified
    the continuation precondition before calling.

    After this call, ``next_token_logits`` reflects the verifier's
    prediction for the token immediately after ``new_tokens[-1]``,
    same contract as :meth:`prefill`.

    Edge case: if ``new_tokens`` is empty (the rare exact-match
    case where the new prompt equals the cached state in length
    and content), this is a no-op — ``next_token_logits`` is
    whatever the previous call left it as.

    Args:
        new_tokens: Token ids to append to the cached state.

    Raises:
        RuntimeError: If no cache exists yet (``self.cache is None``).
    """
    if self.cache is None:
        raise RuntimeError(
            "prefill_incremental called before any prefill; cache "
            "is None. Call path_select first and route NewSession "
            "to prefill() instead."
        )
    if not new_tokens:
        return  # no-op; cache state is already correct

    # Snapshot once so the block that is forwarded and the counts that
    # are committed are guaranteed to describe the same tokens.
    tokens = list(new_tokens)

    # Forward the new tokens; treat all as accepted (this is prompt,
    # not speculative draft), so forwarded == accepted.
    block_logits = self.forward_block(tokens)
    self.commit_or_truncate(forwarded=len(tokens), accepted=len(tokens))

    # next_token_logits = the prediction for the token AFTER the
    # last incrementally prefilled token.
    self.next_token_logits = block_logits[-1].clone()
322+
219323
@torch.no_grad()
220324
def append_token(self, token_id: int) -> torch.Tensor:
221325
"""Forward a single token (e.g., correction or bonus) into the cache.
@@ -286,6 +390,60 @@ def _truncate_tail_in_place(self, drop: int) -> None:
286390
layer.keys = keys[:, :, :keep, :].contiguous()
287391
layer.values = values[:, :, :keep, :].contiguous()
288392

393+
def _cached_global_positions(self) -> List[int]:
    """Return the global token positions currently held in the cache.

    The cache holds positions as a sink+window window over global
    positions. Concretely:

    - If ``next_global_position <= sink + window``: positions
      ``[0, 1, ..., next_global_position - 1]`` (everything
      still fits, no eviction yet).
    - Otherwise: positions ``[0..sink-1]`` (the sink prefix) +
      ``[next_global_position - window..next_global_position -
      1]`` (the sliding window).

    Length is expected to equal ``len(self.cached_token_sequence)``
    (which equals the K/V cache seq dim, by INV-1); this method does
    not check that itself.

    Returns:
        The positions in ascending order; empty when
        ``next_global_position`` is 0.
    """
    n = self.next_global_position
    if n == 0:
        return []

    sink = self.config.sink_size
    window = self.config.window_size
    if n <= sink + window:
        # Everything seen so far still fits in the budget.
        return list(range(n))

    # Past the budget: the sink prefix plus the trailing window.
    return [*range(sink), *range(n - window, n)]
419+
420+
def _prompt_matches_cached_positions(self, prompt: List[int]) -> bool:
    """Token-id-level check for ADR 0007 §2.4.a.2.

    Returns True iff, for every global position ``p`` currently
    held in the cache, ``prompt[p]`` equals the cached token id
    at the matching slot. False on any mismatch (caller routes
    to NewSession path).

    NOTE(review): only positions reported by
    ``_cached_global_positions`` are compared; prompt tokens at any
    other position are not checked here. Confirm this is the intended
    reading of ADR 0007 §2.4.a.2.

    Args:
        prompt: Token ids of the full incoming prompt.

    Raises:
        AssertionError: If the position list and
            ``cached_token_sequence`` differ in length.
    """
    positions = self._cached_global_positions()
    cached = self.cached_token_sequence
    if len(positions) != len(cached):
        # INV-1 should make this impossible; if it triggers,
        # it's a critical bug in the cache-mutation layer.
        raise AssertionError(
            f"_prompt_matches_cached_positions: position list of "
            f"length {len(positions)} disagrees with parallel "
            f"sequence of length {len(cached)}; "
            f"INV-1 should have caught this earlier"
        )

    prompt_len = len(prompt)
    # The bounds test short-circuits before indexing, so a prompt that
    # is too short yields False instead of raising IndexError.
    return all(
        pos < prompt_len and int(prompt[pos]) == int(token)
        for pos, token in zip(positions, cached)
    )
446+
289447
def _cache_seq_length(self) -> int:
290448
"""Return the seq dim of the cache K/V tensors, or 0 if empty.
291449

0 commit comments

Comments
 (0)