Commit 529abcf

Merge pull request #71 from FluffyAIcode/AgentMemory/v04-pr-k1b-kv-merge-8e7f
PR-K1.B: K/V merge primitive for v0.4 dLM K/V Restoration
2 parents 8dea55f + 6e80430 commit 529abcf

3 files changed

Lines changed: 773 additions & 0 deletions

File tree

inference_engine/v04/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -33,9 +33,17 @@
     capture_proposer_kv,
     register_kv_capture_hooks,
 )
+from inference_engine.v04.kv_merge import (
+    compute_evicted_positions,
+    merge_kv_at_evicted_positions,
+)
 
 __all__ = [
+    # K1.A — capture
     "KVCapture",
     "capture_proposer_kv",
     "register_kv_capture_hooks",
+    # K1.B — merge
+    "compute_evicted_positions",
+    "merge_kv_at_evicted_positions",
 ]

inference_engine/v04/kv_merge.py

Lines changed: 309 additions & 0 deletions
@@ -0,0 +1,309 @@
+"""Merge verifier's locally-computed K/V with captured proposer K/V at
+evicted positions.
+
+ADR 0008 §11.5 — at every attention layer in the v0.4 verifier, the
+attention input K/V is the union of:
+
+* **K_local / V_local** — the verifier's own K/V projections of the
+  positions still in its sink+window cache (computed normally during
+  the verifier's forward pass).
+* **K_captured / V_captured** — K/V at positions the verifier has
+  evicted, reconstructed from the dLM proposer's parallel forward
+  via the K1.A capture machinery.
+
+This module implements that union as a pure tensor operation. It is
+**RoPE-agnostic**: it merges raw projection outputs at any consistent
+RoPE state (both pre-RoPE, or both post-RoPE — but never one of
+each). The caller is responsible for applying RoPE consistently to
+both branches before or after the merge.
+
+For K1.B we merge **pre-RoPE** because that's the form K1.A captures
+in (see ``inference_engine/v04/kv_capture.py`` module docstring §
+"Why pre-RoPE rather than post-RoPE"). K1.C will apply RoPE inside
+the verifier's standard attention forward, after the merge, using
+HF's own ``apply_rotary_pos_emb`` so we don't duplicate that
+machinery.
+
+API contract
+------------
+
+The single public entry point :func:`merge_kv_at_evicted_positions`
+takes ``[B, T, num_kv_heads, head_dim]`` local tensors and
+``[B, len(evicted), num_kv_heads, head_dim]`` captured tensors plus
+the list of evicted positions, and returns ``[B, T, num_kv_heads,
+head_dim]`` merged tensors with K/V at evicted positions replaced
+by the captured values.
+
+Empty evicted list is the no-op identity (returns a clone of the
+local tensors). Position lists are validated for sortedness, dedup,
+and range; the captured-tensor T-dim must equal ``len(positions)``;
+all shape/dtype/device must be consistent. Mismatches raise
+``ValueError`` per ADR 0008 §6.2 (no silent fallback).
+
+The returned merged tensors are clones of the inputs — the caller
+can mutate them freely without affecting the input tensors. This
+costs an extra allocation per layer per step, which is fine in the
+v0.4 architecture where merge happens once per attention forward.
+The clone is needed because ``index_copy_`` is in-place; if we
+mutated K_local directly we would surprise callers who reuse it
+elsewhere.
+
+Differentiability
+-----------------
+
+Gradient flows through the captured branch (so a learnable cross-
+model projection ``f_θ`` in K2/K3 can be trained end-to-end through
+the merge). Gradient flows through the local branch only at
+non-evicted positions; at evicted positions the local values are
+overwritten and contribute no gradient. This is a deliberate
+boundary condition matching the v0.4 architecture: at evicted
+positions the verifier's local representation is irrelevant by
+design.
+"""
+
+from __future__ import annotations
+
+from typing import List, Sequence, Tuple
+
+import torch
+
+
+def _validate_positions(
+    positions: Sequence[int],
+    seq_len: int,
+) -> List[int]:
+    """Validate and return the sorted-deduped list of positions.
+
+    Raises ``ValueError`` on:
+      * unsorted input
+      * duplicates
+      * any position < 0
+      * any position >= ``seq_len``
+    """
+    if not positions:
+        return []
+    positions_list = list(positions)
+    sorted_positions = sorted(set(positions_list))
+    if sorted_positions != positions_list:
+        raise ValueError(
+            "evicted_positions must be sorted ascending with no "
+            f"duplicates; got {positions_list}"
+        )
+    if sorted_positions[0] < 0 or sorted_positions[-1] >= seq_len:
+        raise ValueError(
+            f"evicted_positions must lie in [0, {seq_len}); "
+            f"got [{sorted_positions[0]}, {sorted_positions[-1]}]"
+        )
+    return sorted_positions
+
+
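The position check above relies on a single comparison, `sorted(set(positions)) != positions`, to reject both unsorted and duplicated input. The following is a minimal standalone sketch of that idea, not the module's code: `is_strictly_ascending` is a hypothetical name introduced only for illustration, and it omits the range check.

```python
from typing import Sequence


def is_strictly_ascending(positions: Sequence[int]) -> bool:
    """Hypothetical helper: True when positions are sorted with no duplicates.

    Deduplicating and sorting yields the canonical form; any input that is
    out of order or repeats a value differs from that form.
    """
    as_list = list(positions)
    return sorted(set(as_list)) == as_list


checks = {
    "sorted, unique": is_strictly_ascending([1, 3, 5]),
    "unsorted": is_strictly_ascending([3, 1, 5]),
    "duplicate": is_strictly_ascending([1, 1, 5]),
    "empty": is_strictly_ascending([]),
}
for label, ok in checks.items():
    print(f"{label}: {'accepted' if ok else 'rejected'}")
```

Strictness matters for the merge because `index_copy_` pairs the i-th index with the i-th slice of the source, and PyTorch documents the result as nondeterministic when the index contains duplicates.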
+def _validate_shapes(
+    K_local: torch.Tensor,
+    V_local: torch.Tensor,
+    K_captured: torch.Tensor,
+    V_captured: torch.Tensor,
+    n_evicted: int,
+) -> None:
+    """Validate the four tensors share consistent batch / head / dim
+    structure and that captured tensors' T-dim equals ``n_evicted``.
+
+    Raises ``ValueError`` on any mismatch.
+    """
+    if K_local.shape != V_local.shape:
+        raise ValueError(
+            f"K_local shape {tuple(K_local.shape)} != V_local shape "
+            f"{tuple(V_local.shape)}"
+        )
+    if K_captured.shape != V_captured.shape:
+        raise ValueError(
+            f"K_captured shape {tuple(K_captured.shape)} != V_captured "
+            f"shape {tuple(V_captured.shape)}"
+        )
+    if K_local.dim() != 4:
+        raise ValueError(
+            "K_local must be 4-D [B, T, num_kv_heads, head_dim]; got "
+            f"shape {tuple(K_local.shape)}"
+        )
+    if K_captured.dim() != 4:
+        raise ValueError(
+            "K_captured must be 4-D [B, n_evicted, num_kv_heads, "
+            f"head_dim]; got shape {tuple(K_captured.shape)}"
+        )
+
+    B_local, T_local, H_local, D_local = K_local.shape
+    B_cap, T_cap, H_cap, D_cap = K_captured.shape
+    if B_local != B_cap:
+        raise ValueError(
+            f"batch mismatch: K_local B={B_local} K_captured B={B_cap}"
+        )
+    if H_local != H_cap:
+        raise ValueError(
+            f"num_kv_heads mismatch: K_local H={H_local} K_captured H={H_cap}"
+        )
+    if D_local != D_cap:
+        raise ValueError(
+            f"head_dim mismatch: K_local D={D_local} K_captured D={D_cap}"
+        )
+    if T_cap != n_evicted:
+        raise ValueError(
+            f"K_captured T-dim {T_cap} != len(evicted_positions) {n_evicted}"
+        )
+
+    if K_local.dtype != K_captured.dtype:
+        raise ValueError(
+            f"dtype mismatch: K_local {K_local.dtype} K_captured "
+            f"{K_captured.dtype}"
+        )
+    if K_local.device != K_captured.device:
+        raise ValueError(
+            f"device mismatch: K_local {K_local.device} K_captured "
+            f"{K_captured.device}"
+        )
+
+
def merge_kv_at_evicted_positions(
165+
K_local: torch.Tensor,
166+
V_local: torch.Tensor,
167+
K_captured: torch.Tensor,
168+
V_captured: torch.Tensor,
169+
evicted_positions: Sequence[int],
170+
) -> Tuple[torch.Tensor, torch.Tensor]:
171+
"""Return ``(K_merged, V_merged)`` where evicted positions are
172+
overridden by the captured proposer K/V and all other positions
173+
keep the verifier's local K/V.
174+
175+
Parameters
176+
----------
177+
K_local
178+
Verifier's K projection at every position. Shape
179+
``[B, T, num_kv_heads, head_dim]``.
180+
V_local
181+
Verifier's V projection at every position. Same shape as
182+
``K_local``.
183+
K_captured
184+
Proposer's K projection at the evicted positions only,
185+
already sliced via :meth:`KVCapture.select_positions`. Shape
186+
``[B, len(evicted_positions), num_kv_heads, head_dim]``.
187+
V_captured
188+
Same as ``K_captured`` for V.
189+
evicted_positions
190+
Sorted-ascending list of positions in ``[0, T)`` whose K/V
191+
come from the captured branch. Empty list is the no-op
192+
identity case (returns clones of the local tensors). Per
193+
ADR 0008 §6.2, unsorted / duplicated / out-of-range inputs
194+
raise rather than silently coerce.
195+
196+
Returns
197+
-------
198+
A tuple ``(K_merged, V_merged)`` of shape ``[B, T, num_kv_heads,
199+
head_dim]``. The returned tensors are clones — mutating them does
200+
not affect the inputs.
201+
202+
Notes
203+
-----
204+
Both inputs and outputs are RoPE-agnostic; the caller must apply
205+
RoPE consistently to both branches before or after the merge.
206+
K1.B uses the merge in pre-RoPE; K1.C will apply RoPE inside the
207+
verifier's standard attention forward (after the merge), reusing
208+
HF's ``apply_rotary_pos_emb``.
209+
210+
Gradient flows through ``K_captured`` / ``V_captured`` for the
211+
evicted positions; through ``K_local`` / ``V_local`` for the
212+
other positions. The local branch's gradient at evicted positions
213+
is severed by the override (those tensors are discarded by the
214+
merge). This is the intentional v0.4 boundary: at evicted
215+
positions, the verifier's local representation is irrelevant by
216+
design.
217+
"""
218+
# Rank check first so K_local.size(1) is meaningfully the T dim.
219+
# We can't run the full shape validation yet because empty
220+
# evicted_positions is the no-op identity case where captured
221+
# tensors are allowed to be empty.
222+
if K_local.dim() != 4:
223+
raise ValueError(
224+
"K_local must be 4-D [B, T, num_kv_heads, head_dim]; got "
225+
f"shape {tuple(K_local.shape)}"
226+
)
227+
228+
sorted_positions = _validate_positions(evicted_positions, K_local.size(1))
229+
230+
if not sorted_positions:
231+
# No evictions: identity. Clone so callers can mutate freely.
232+
return K_local.clone(), V_local.clone()
233+
234+
_validate_shapes(
235+
K_local, V_local, K_captured, V_captured, len(sorted_positions),
236+
)
237+
238+
idx = torch.tensor(sorted_positions, device=K_local.device, dtype=torch.long)
239+
240+
K_merged = K_local.clone()
241+
V_merged = V_local.clone()
242+
K_merged.index_copy_(dim=1, index=idx, source=K_captured)
243+
V_merged.index_copy_(dim=1, index=idx, source=V_captured)
244+
return K_merged, V_merged
245+
246+
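The clone-then-`index_copy_` pattern and the gradient boundary described in the docstring can be checked on toy tensors. This is a minimal sketch under assumed shapes (B=1, T=5, one KV head, head_dim=2), all chosen only for illustration. It repeats the core tensor operation inline and skips validation; it does not call the function from this commit.

```python
import torch

# Toy shapes, illustrative only: [B=1, T=5, num_kv_heads=1, head_dim=2].
K_local = torch.zeros(1, 5, 1, 2, requires_grad=True)
evicted = [1, 3]  # assumed evicted positions, sorted ascending
K_captured = torch.ones(1, len(evicted), 1, 2, requires_grad=True)

idx = torch.tensor(evicted, dtype=torch.long)

# Clone first so the in-place index_copy_ never touches K_local itself.
K_merged = K_local.clone()
K_merged.index_copy_(1, idx, K_captured)

# Backpropagate a simple scalar to see which inputs receive gradient.
K_merged.sum().backward()

merged_per_position = K_merged.detach()[0, :, 0, 0].tolist()
grad_local_per_position = K_local.grad[0, :, 0, 0].tolist()
grad_captured_per_position = K_captured.grad[0, :, 0, 0].tolist()

print("merged:       ", merged_per_position)
print("grad local:   ", grad_local_per_position)
print("grad captured:", grad_captured_per_position)
```

In this sketch the local gradient is zero exactly at the evicted positions and the captured tensor receives gradient at every slice, which matches the boundary condition the docstring describes. `K_local` itself stays all zeros because only the clone was written to.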
+def compute_evicted_positions(
+    seq_len: int,
+    sink_size: int,
+    window_size: int,
+) -> List[int]:
+    """Return the list of token positions that fall **outside** the
+    sink+window range over a sequence of length ``seq_len``.
+
+    The v0.4 verifier's permanent KV cache holds K/V at the union
+    ``{0, 1, ..., sink-1} ∪ {seq_len-window, ..., seq_len-1}``. All
+    other positions are "evicted" — their K/V are reconstructed from
+    the dLM proposer's transient forward each step. This helper
+    materialises the evicted list once per generation step so callers
+    can pass it to :meth:`KVCapture.select_positions` and to
+    :func:`merge_kv_at_evicted_positions` without recomputing.
+
+    Parameters
+    ----------
+    seq_len
+        Total number of token positions in the current attention
+        view (prompt + drafts so far).
+    sink_size
+        Number of attention sinks at the head of the sequence (ADR
+        0001 + ADR 0008 §2.3 v0.3 default: 4).
+    window_size
+        Width of the trailing sliding window (ADR 0001 + ADR 0008
+        §2.3 v0.3 default: 64).
+
+    Returns
+    -------
+    Sorted-ascending list of evicted position indices. Empty when
+    ``seq_len <= sink_size + window_size`` (everything fits in the
+    cache, no evictions needed). Always contiguous: positions
+    ``[sink_size, seq_len - window_size)``.
+
+    Raises
+    ------
+    ValueError
+        If any of ``seq_len``, ``sink_size``, ``window_size`` is
+        negative.
+
+    Notes
+    -----
+    Position ranges:
+
+    * ``[0, sink_size)`` — sink (kept in cache, NOT evicted)
+    * ``[sink_size, seq_len - window_size)`` — middle (EVICTED)
+    * ``[seq_len - window_size, seq_len)`` — window (kept in cache)
+
+    When sink and window overlap (``sink_size + window_size >=
+    seq_len``), nothing is evicted. The function returns ``[]`` and
+    the v0.4 architecture degenerates to standard full-attention
+    inference at that step.
+    """
+    if seq_len < 0 or sink_size < 0 or window_size < 0:
+        raise ValueError(
+            f"seq_len={seq_len}, sink_size={sink_size}, "
+            f"window_size={window_size} must all be non-negative"
+        )
+    if seq_len <= sink_size + window_size:
+        # Sink + window covers the whole sequence; nothing to evict.
+        return []
+    return list(range(sink_size, seq_len - window_size))
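As a worked example of the sink / middle / window split described in the docstring, the sketch below recomputes the evicted range in plain Python and checks that evicted and kept positions together cover the sequence exactly once. The names `evicted_range` and `kept_positions` are hypothetical, introduced only for this illustration; the sizes 4 and 64 are the v0.3 defaults quoted in the docstring.

```python
from typing import List


def evicted_range(seq_len: int, sink_size: int, window_size: int) -> List[int]:
    """Hypothetical re-statement of the middle range [sink, seq_len - window)."""
    if seq_len <= sink_size + window_size:
        return []  # sink + window cover everything
    return list(range(sink_size, seq_len - window_size))


def kept_positions(seq_len: int, sink_size: int, window_size: int) -> List[int]:
    """Hypothetical helper: positions held in the sink+window cache."""
    sink = set(range(min(sink_size, seq_len)))
    window = set(range(max(seq_len - window_size, 0), seq_len))
    return sorted(sink | window)


# Small case: 10 tokens, 2 sinks, window of 3.
small_evicted = evicted_range(10, 2, 3)
small_kept = kept_positions(10, 2, 3)

# With sink=4 and window=64, nothing is evicted until seq_len exceeds 68.
at_capacity = evicted_range(68, 4, 64)
just_over = evicted_range(70, 4, 64)

# Evicted and kept positions should partition range(seq_len).
is_partition = sorted(small_evicted + small_kept) == list(range(10))

print(small_evicted, small_kept, at_capacity, just_over, is_partition)
```

For the small case the middle positions 2 through 6 are evicted while 0, 1 and 7 through 9 stay cached. With the quoted defaults, a 70-token sequence evicts only positions 4 and 5.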
