@@ -216,6 +216,110 @@ def commit_or_truncate(
216216 self ._record_peak_kv ()
217217 self ._assert_cache_invariant_1 ()
218218
219+ def path_select (self , prompt : List [int ]) -> "PathPlan" :
220+ """Select between continuation and new-session paths for ``prompt``.
221+
222+ Implements the deterministic two-path selection from ADR 0007
223+ §2.4. Returns ``ContinuationPlan`` when the new prompt extends
224+ the cached state monotonically (§2.4.a), or ``NewSession``
225+ otherwise (§2.4.b).
226+
227+ This is **not** a fallback. Both paths are first-class correct
228+ actions for their respective input classes.
229+
230+ Asserts ADR 0007 §2.9 INV-2 (position monotonicity within a
231+ session): a ``ContinuationPlan`` always has
232+ ``skip_n == self.next_global_position``. Violation indicates a
233+ bug in the path-selection algorithm and raises rather than
234+ falling back.
235+ """
236+ from .path_plan import ContinuationPlan , NewSession # avoid circular
237+
238+ if not prompt :
239+ raise ValueError ("prompt must be non-empty" )
240+ prompt_list = list (prompt )
241+
242+ # Cold start (§2.4.b case 1)
243+ if self .cache is None or self .cache_logical_size == 0 :
244+ return NewSession (prompt = prompt_list )
245+
246+ cache_end = self .next_global_position
247+
248+ # Shorter history (§2.4.b case 2): the new prompt cannot extend
249+ # the cached state because it ends before the cache's logical
250+ # end. This is a different conversation from the client's side.
251+ if len (prompt_list ) < cache_end :
252+ return NewSession (prompt = prompt_list )
253+
254+ # Diverging history (§2.4.b case 3): the cached tokens disagree
255+ # with the new prompt at any cached logical position.
256+ if not self ._prompt_matches_cached_positions (prompt_list ):
257+ return NewSession (prompt = prompt_list )
258+
259+ # Continuation precondition satisfied.
260+ skip_n = cache_end
261+ new_tokens = prompt_list [skip_n :]
262+
263+ # INV-2 (ADR 0007 §2.9): structural invariant — skip_n must
264+ # equal next_global_position. If it doesn't, the planning logic
265+ # above is buggy; surface as a critical error per §2.9.
266+ if skip_n != self .next_global_position :
267+ raise AssertionError (
268+ f"INV-2 violated (position monotonicity): planned "
269+ f"skip_n={ skip_n } but next_global_position="
270+ f"{ self .next_global_position } . Continuation must "
271+ f"extend exactly from the cache's logical end. This "
272+ f"is a bug in path_select; ADR 0007 §2.9 forbids "
273+ f"silent recovery. cache_logical_size="
274+ f"{ self .cache_logical_size } , "
275+ f"cached_token_sequence_len="
276+ f"{ len (self .cached_token_sequence )} ."
277+ )
278+
279+ return ContinuationPlan (skip_n = skip_n , new_tokens = new_tokens )
280+
281+ @torch .no_grad ()
282+ def prefill_incremental (self , new_tokens : List [int ]) -> None :
283+ """Run incremental prefill on ``new_tokens``, reusing cached state.
284+
285+ Counterpart to :meth:`prefill`: where ``prefill`` resets and
286+ rebuilds the cache from scratch, ``prefill_incremental`` keeps
287+ the existing cache and only forwards the new tokens. Used by
288+ the continuation path of ADR 0007 §2.4.a.
289+
290+ ``new_tokens`` is the suffix returned by ``path_select`` in
291+ ``ContinuationPlan.new_tokens``; the caller must have verified
292+ the continuation precondition before calling.
293+
294+ After this call, ``next_token_logits`` reflects the verifier's
295+ prediction for the token immediately after ``new_tokens[-1]``,
296+ same contract as :meth:`prefill`.
297+
298+ Edge case: if ``new_tokens`` is empty (the rare exact-match
299+ case where the new prompt equals the cached state in length
300+ and content), this is a no-op — ``next_token_logits`` is
301+ whatever the previous call left it as, which is still the
302+ correct prediction at that position.
303+ """
304+ if self .cache is None :
305+ raise RuntimeError (
306+ "prefill_incremental called before any prefill; cache "
307+ "is None. Call path_select first and route NewSession "
308+ "to prefill() instead."
309+ )
310+ if not new_tokens :
311+ return # no-op; cache state is already correct
312+ # Forward the new tokens; treat all as accepted (this is prompt,
313+ # not speculative draft). forward_block + commit_or_truncate
314+ # already maintain cached_token_sequence and INV-1 in lockstep.
315+ block_logits = self .forward_block (list (new_tokens ))
316+ self .commit_or_truncate (
317+ forwarded = len (new_tokens ), accepted = len (new_tokens )
318+ )
319+ # next_token_logits = the prediction for the token AFTER the
320+ # last incrementally prefilled token.
321+ self .next_token_logits = block_logits [- 1 ].clone ()
322+
219323 @torch .no_grad ()
220324 def append_token (self , token_id : int ) -> torch .Tensor :
221325 """Forward a single token (e.g., correction or bonus) into the cache.
@@ -286,6 +390,60 @@ def _truncate_tail_in_place(self, drop: int) -> None:
286390 layer .keys = keys [:, :, :keep , :].contiguous ()
287391 layer .values = values [:, :, :keep , :].contiguous ()
288392
393+ def _cached_global_positions (self ) -> List [int ]:
394+ """Return the global token positions currently held in the cache.
395+
396+ The cache holds positions as a sink+window window over global
397+ positions. Concretely:
398+
399+ - If ``next_global_position <= sink + window``: positions
400+ ``[0, 1, ..., next_global_position - 1]`` (everything
401+ still fits, no eviction yet).
402+ - Otherwise: positions ``[0..sink-1]`` (the sink prefix) +
403+ ``[next_global_position - window..next_global_position -
404+ 1]`` (the sliding window).
405+
406+ Length always equals ``len(self.cached_token_sequence)``
407+ (which equals the K/V cache seq dim, by INV-1).
408+ """
409+ n = self .next_global_position
410+ if n == 0 :
411+ return []
412+ budget = self .config .sink_size + self .config .window_size
413+ if n <= budget :
414+ return list (range (n ))
415+ sink_positions = list (range (self .config .sink_size ))
416+ window_start = n - self .config .window_size
417+ window_positions = list (range (window_start , n ))
418+ return sink_positions + window_positions
419+
420+ def _prompt_matches_cached_positions (self , prompt : List [int ]) -> bool :
421+ """Token-id-level check for ADR 0007 §2.4.a.2.
422+
423+ Returns True iff, for every global position ``p`` currently
424+ held in the cache, ``prompt[p]`` equals the cached token id
425+ at the matching slot. False on any mismatch (caller routes
426+ to NewSession path).
427+ """
428+ positions = self ._cached_global_positions ()
429+ if len (positions ) != len (self .cached_token_sequence ):
430+ # INV-1 should make this impossible; if it triggers,
431+ # it's a critical bug in the cache-mutation layer.
432+ raise AssertionError (
433+ f"_prompt_matches_cached_positions: position list of "
434+ f"length { len (positions )} disagrees with parallel "
435+ f"sequence of length { len (self .cached_token_sequence )} ; "
436+ f"INV-1 should have caught this earlier"
437+ )
438+ for cache_idx , global_pos in enumerate (positions ):
439+ if global_pos >= len (prompt ):
440+ return False
441+ if int (prompt [global_pos ]) != int (
442+ self .cached_token_sequence [cache_idx ]
443+ ):
444+ return False
445+ return True
446+
289447 def _cache_seq_length (self ) -> int :
290448 """Return the seq dim of the cache K/V tensors, or 0 if empty.
291449
0 commit comments