@@ -76,6 +76,23 @@ def __init__(self, config: Optional[VerifierConfig] = None) -> None:
7676 # Logits at `next_global_position` predicting the next token. Updated
7777 # after every forward pass.
7878 self .next_token_logits : Optional [torch .Tensor ] = None
79+ # Parallel record of the token id at every K/V cache slot, in the
80+ # same physical order as ``self.cache.layers[*].keys``. Maintained
81+ # synchronously with the K/V tensors by every cache mutation
82+ # method below. Required by ADR 0007 §2.2 + §2.9 INV-1: the
83+ # path-selection algorithm (PR 7-2) needs token-id-level prefix
84+ # matching against the cache, and the K/V tensors don't expose
85+ # token ids.
86+ #
87+ # Storage: at most ``sink_size + window_size`` int entries, so
88+ # bounded at the same constant the K/V cache is bounded at
89+ # (e.g. 68 entries × 8 bytes per Python int = 544 bytes,
90+ # negligible vs the 7.4 MiB K/V).
91+ #
92+ # Invariant INV-1 (ADR 0007 §2.9): after every cache mutation,
93+ # ``len(self.cached_token_sequence)`` equals the K/V tensor
94+ # sequence dimension. Enforced by ``_assert_cache_invariant_1``.
95+ self .cached_token_sequence : List [int ] = []
7996
8097 self .stats = VerifierStats (
8198 weight_bytes = sum (p .numel () * p .element_size () for p in self .model .parameters ())
@@ -87,6 +104,8 @@ def reset(self) -> None:
87104 self .cache_logical_size = 0
88105 self .next_global_position = 0
89106 self .next_token_logits = None
107+ self .cached_token_sequence = []
108+ self ._assert_cache_invariant_1 ()
90109
91110 @torch .no_grad ()
92111 def prefill (self , prompt_ids : List [int ]) -> None :
@@ -112,11 +131,19 @@ def prefill(self, prompt_ids: List[int]) -> None:
112131 self .next_global_position = L
113132 self .next_token_logits = outputs .logits [0 , - 1 ].clone ()
114133
134+ # Update parallel token sequence in lockstep with the K/V cache.
135+ # After this prefill the cache holds K/V for all L tokens; the
136+ # subsequent ``_trim_cache_in_place`` will drop middle entries
137+ # to enforce sink+window. We mirror that exact transformation
138+ # on ``cached_token_sequence``.
139+ self .cached_token_sequence = list (prompt_ids )
140+
115141 self ._record_peak_activation (outputs .logits )
116142 self ._trim_cache_in_place ()
117143 self ._record_peak_kv ()
118144 self .stats .forward_calls += 1
119145 self .stats .tokens_consumed += L
146+ self ._assert_cache_invariant_1 ()
120147
121148 @torch .no_grad ()
122149 def forward_block (self , tokens : List [int ]) -> torch .Tensor :
@@ -153,9 +180,13 @@ def forward_block(self, tokens: List[int]) -> torch.Tensor:
153180 self .cache = outputs .past_key_values
154181 # Cache provisionally has cache_start + L slots until commit/truncate.
155182 self .cache_logical_size = cache_start + L
183+ # Mirror the provisional extension on the parallel sequence;
184+ # commit_or_truncate will drop the unaccepted tail in lockstep.
185+ self .cached_token_sequence = self .cached_token_sequence + list (tokens )
156186 self ._record_peak_activation (outputs .logits )
157187 self .stats .forward_calls += 1
158188 self .stats .tokens_consumed += L
189+ self ._assert_cache_invariant_1 ()
159190 # Don't trim yet — caller decides how many tokens were accepted.
160191 return outputs .logits [0 ].clone () # [L, V]
161192
@@ -177,10 +208,13 @@ def commit_or_truncate(
177208 drop = forwarded - accepted
178209 if drop > 0 :
179210 self ._truncate_tail_in_place (drop )
211+ # Mirror the tail truncation on the parallel sequence.
212+ self .cached_token_sequence = self .cached_token_sequence [:- drop ]
180213 self .cache_logical_size -= drop
181214 self .next_global_position += accepted
182215 self ._trim_cache_in_place ()
183216 self ._record_peak_kv ()
217+ self ._assert_cache_invariant_1 ()
184218
185219 @torch .no_grad ()
186220 def append_token (self , token_id : int ) -> torch .Tensor :
@@ -226,6 +260,12 @@ def _trim_cache_in_place(self) -> None:
226260 # peak_kv_bytes would over-report.
227261 layer .keys = torch .cat ([sink_k , tail_k ], dim = 2 ).contiguous ()
228262 layer .values = torch .cat ([sink_v , tail_v ], dim = 2 ).contiguous ()
263+ # Mirror the same sink+window slice on the parallel token sequence
264+ # so cached_token_sequence stays in lockstep with the K/V tensors.
265+ self .cached_token_sequence = (
266+ self .cached_token_sequence [:sink ]
267+ + self .cached_token_sequence [- keep_window :]
268+ )
229269 self .cache_logical_size = budget
230270
231271 def _truncate_tail_in_place (self , drop : int ) -> None :
@@ -246,6 +286,44 @@ def _truncate_tail_in_place(self, drop: int) -> None:
246286 layer .keys = keys [:, :, :keep , :].contiguous ()
247287 layer .values = values [:, :, :keep , :].contiguous ()
248288
289+ def _cache_seq_length (self ) -> int :
290+ """Return the seq dim of the cache K/V tensors, or 0 if empty.
291+
292+ Reads from the first non-empty layer; all layers share the same
293+ seq dim by construction (every K/V mutation in this class
294+ applies the same shape transformation across all layers).
295+ """
296+ if self .cache is None :
297+ return 0
298+ for layer in self .cache .layers :
299+ keys = getattr (layer , "keys" , None )
300+ if keys is not None :
301+ return int (keys .shape [2 ])
302+ return 0
303+
304+ def _assert_cache_invariant_1 (self ) -> None :
305+ """ADR 0007 §2.9 INV-1: parallel-sequence consistency.
306+
307+ After every cache mutation, ``len(self.cached_token_sequence)``
308+ must equal the K/V tensor sequence dimension. Violation
309+ indicates a bug in the cache-mutation path; per ADR 0007 §2.9
310+ the implementation must raise — never silently recover, never
311+ fall back, never re-sync.
312+ """
313+ actual = len (self .cached_token_sequence )
314+ expected = self ._cache_seq_length ()
315+ if actual != expected :
316+ raise AssertionError (
317+ f"INV-1 violated (parallel-sequence consistency): "
318+ f"cached_token_sequence has { actual } entries but K/V "
319+ f"cache seq dim is { expected } . This is a bug in the "
320+ f"verifier's cache-mutation path; per ADR 0007 §2.9 it "
321+ f"must surface as a critical error rather than be "
322+ f"silently recovered. cache_logical_size="
323+ f"{ self .cache_logical_size } , "
324+ f"next_global_position={ self .next_global_position } ."
325+ )
326+
249327 def live_kv_bytes (self ) -> int :
250328 """Return the current size of the verifier's live KV cache in bytes.
251329
0 commit comments