|
| 1 | +"""Merge verifier's locally-computed K/V with captured proposer K/V at |
| 2 | +evicted positions. |
| 3 | +
|
| 4 | +ADR 0008 §11.5 — at every attention layer in the v0.4 verifier, the |
| 5 | +attention input K/V is the union of: |
| 6 | +
|
| 7 | +* **K_local / V_local** — the verifier's own K/V projections of the |
| 8 | + positions still in its sink+window cache (computed normally during |
| 9 | + the verifier's forward pass). |
| 10 | +* **K_captured / V_captured** — K/V at positions the verifier has |
| 11 | + evicted, reconstructed from the dLM proposer's parallel forward |
| 12 | + via the K1.A capture machinery. |
| 13 | +
|
| 14 | +This module implements that union as a pure tensor operation. It is |
| 15 | +**RoPE-agnostic**: it merges raw projection outputs at any consistent |
| 16 | +RoPE state (both pre-RoPE, or both post-RoPE — but never one of |
| 17 | +each). The caller is responsible for applying RoPE consistently to |
| 18 | +both branches before or after the merge. |
| 19 | +
|
| 20 | +For K1.B we merge **pre-RoPE** because that's the form K1.A captures |
| 21 | +in (see ``inference_engine/v04/kv_capture.py`` module docstring § |
| 22 | +"Why pre-RoPE rather than post-RoPE"). K1.C will apply RoPE inside |
| 23 | +the verifier's standard attention forward, after the merge, using |
| 24 | +HF's own ``apply_rotary_pos_emb`` so we don't duplicate that |
| 25 | +machinery. |
| 26 | +
|
| 27 | +API contract |
| 28 | +------------ |
| 29 | +
|
| 30 | +The single public entry point :func:`merge_kv_at_evicted_positions` |
| 31 | +takes ``[B, T, num_kv_heads, head_dim]`` local tensors and |
| 32 | +``[B, len(evicted), num_kv_heads, head_dim]`` captured tensors plus |
| 33 | +the list of evicted positions, and returns ``[B, T, num_kv_heads, |
| 34 | +head_dim]`` merged tensors with K/V at evicted positions replaced |
| 35 | +by the captured values. |
| 36 | +
|
| 37 | +Empty evicted list is the no-op identity (returns a clone of the |
| 38 | +local tensors). Position lists are validated for sortedness, dedup, |
| 39 | +and range; the captured-tensor T-dim must equal ``len(positions)``; |
| 40 | +all shape/dtype/device must be consistent. Mismatches raise |
| 41 | +``ValueError`` per ADR 0008 §6.2 (no silent fallback). |
| 42 | +
|
| 43 | +The returned merged tensors are clones of the inputs — the caller |
| 44 | +can mutate them freely without affecting the input tensors. This |
| 45 | +costs an extra allocation per layer per step, which is fine in the |
| 46 | +v0.4 architecture where merge happens once per attention forward. |
| 47 | +The clone is needed because ``index_copy_`` is in-place; if we |
| 48 | +mutated K_local directly we would surprise callers who reuse it |
| 49 | +elsewhere. |
| 50 | +
|
| 51 | +Differentiability |
| 52 | +----------------- |
| 53 | +
|
| 54 | +Gradient flows through the captured branch (so a learnable cross- |
| 55 | +model projection ``f_θ`` in K2/K3 can be trained end-to-end through |
| 56 | +the merge). Gradient flows through the local branch only at |
| 57 | +non-evicted positions; at evicted positions the local values are |
| 58 | +overwritten and contribute no gradient. This is a deliberate |
| 59 | +boundary condition matching the v0.4 architecture: at evicted |
| 60 | +positions the verifier's local representation is irrelevant by |
| 61 | +design. |
| 62 | +""" |
| 63 | + |
| 64 | +from __future__ import annotations |
| 65 | + |
| 66 | +from typing import List, Sequence, Tuple |
| 67 | + |
| 68 | +import torch |
| 69 | + |
| 70 | + |
| 71 | +def _validate_positions( |
| 72 | + positions: Sequence[int], |
| 73 | + seq_len: int, |
| 74 | +) -> List[int]: |
| 75 | + """Validate and return the sorted-deduped list of positions. |
| 76 | +
|
| 77 | + Raises ``ValueError`` on: |
| 78 | + * unsorted input |
| 79 | + * duplicates |
| 80 | + * any position < 0 |
| 81 | + * any position >= ``seq_len`` |
| 82 | + """ |
| 83 | + if not positions: |
| 84 | + return [] |
| 85 | + positions_list = list(positions) |
| 86 | + sorted_positions = sorted(set(positions_list)) |
| 87 | + if sorted_positions != positions_list: |
| 88 | + raise ValueError( |
| 89 | + "evicted_positions must be sorted ascending with no " |
| 90 | + f"duplicates; got {positions_list}" |
| 91 | + ) |
| 92 | + if sorted_positions[0] < 0 or sorted_positions[-1] >= seq_len: |
| 93 | + raise ValueError( |
| 94 | + f"evicted_positions must lie in [0, {seq_len}); " |
| 95 | + f"got [{sorted_positions[0]}, {sorted_positions[-1]}]" |
| 96 | + ) |
| 97 | + return sorted_positions |
| 98 | + |
| 99 | + |
| 100 | +def _validate_shapes( |
| 101 | + K_local: torch.Tensor, |
| 102 | + V_local: torch.Tensor, |
| 103 | + K_captured: torch.Tensor, |
| 104 | + V_captured: torch.Tensor, |
| 105 | + n_evicted: int, |
| 106 | +) -> None: |
| 107 | + """Validate the four tensors share consistent batch / head / dim |
| 108 | + structure and that captured tensors' T-dim equals ``n_evicted``. |
| 109 | +
|
| 110 | + Raises ``ValueError`` on any mismatch. |
| 111 | + """ |
| 112 | + if K_local.shape != V_local.shape: |
| 113 | + raise ValueError( |
| 114 | + f"K_local shape {tuple(K_local.shape)} != V_local shape " |
| 115 | + f"{tuple(V_local.shape)}" |
| 116 | + ) |
| 117 | + if K_captured.shape != V_captured.shape: |
| 118 | + raise ValueError( |
| 119 | + f"K_captured shape {tuple(K_captured.shape)} != V_captured " |
| 120 | + f"shape {tuple(V_captured.shape)}" |
| 121 | + ) |
| 122 | + if K_local.dim() != 4: |
| 123 | + raise ValueError( |
| 124 | + "K_local must be 4-D [B, T, num_kv_heads, head_dim]; got " |
| 125 | + f"shape {tuple(K_local.shape)}" |
| 126 | + ) |
| 127 | + if K_captured.dim() != 4: |
| 128 | + raise ValueError( |
| 129 | + "K_captured must be 4-D [B, n_evicted, num_kv_heads, " |
| 130 | + f"head_dim]; got shape {tuple(K_captured.shape)}" |
| 131 | + ) |
| 132 | + |
| 133 | + B_local, T_local, H_local, D_local = K_local.shape |
| 134 | + B_cap, T_cap, H_cap, D_cap = K_captured.shape |
| 135 | + if B_local != B_cap: |
| 136 | + raise ValueError( |
| 137 | + f"batch mismatch: K_local B={B_local} K_captured B={B_cap}" |
| 138 | + ) |
| 139 | + if H_local != H_cap: |
| 140 | + raise ValueError( |
| 141 | + f"num_kv_heads mismatch: K_local H={H_local} K_captured H={H_cap}" |
| 142 | + ) |
| 143 | + if D_local != D_cap: |
| 144 | + raise ValueError( |
| 145 | + f"head_dim mismatch: K_local D={D_local} K_captured D={D_cap}" |
| 146 | + ) |
| 147 | + if T_cap != n_evicted: |
| 148 | + raise ValueError( |
| 149 | + f"K_captured T-dim {T_cap} != len(evicted_positions) {n_evicted}" |
| 150 | + ) |
| 151 | + |
| 152 | + if K_local.dtype != K_captured.dtype: |
| 153 | + raise ValueError( |
| 154 | + f"dtype mismatch: K_local {K_local.dtype} K_captured " |
| 155 | + f"{K_captured.dtype}" |
| 156 | + ) |
| 157 | + if K_local.device != K_captured.device: |
| 158 | + raise ValueError( |
| 159 | + f"device mismatch: K_local {K_local.device} K_captured " |
| 160 | + f"{K_captured.device}" |
| 161 | + ) |
| 162 | + |
| 163 | + |
def merge_kv_at_evicted_positions(
    K_local: torch.Tensor,
    V_local: torch.Tensor,
    K_captured: torch.Tensor,
    V_captured: torch.Tensor,
    evicted_positions: Sequence[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return ``(K_merged, V_merged)`` where evicted positions are
    overridden by the captured proposer K/V and all other positions
    keep the verifier's local K/V.

    Parameters
    ----------
    K_local
        Verifier's K projection at every position. Shape
        ``[B, T, num_kv_heads, head_dim]``.
    V_local
        Verifier's V projection at every position. Same shape as
        ``K_local``.
    K_captured
        Proposer's K projection at the evicted positions only. Shape
        ``[B, len(evicted_positions), num_kv_heads, head_dim]``.
    V_captured
        Same as ``K_captured`` for V.
    evicted_positions
        Sorted-ascending list of positions in ``[0, T)`` whose K/V
        come from the captured branch. Empty list is the no-op
        identity case (returns clones of the local tensors; the
        captured tensors are not inspected). Unsorted / duplicated /
        out-of-range inputs raise rather than being silently coerced.

    Returns
    -------
    A tuple ``(K_merged, V_merged)`` with the shape of ``K_local``.
    The returned tensors are clones — mutating them does not affect
    the inputs.

    Raises
    ------
    ValueError
        If ``K_local`` is not 4-D, if ``K_local`` and ``V_local``
        differ in shape (also checked on the empty-eviction path),
        or if position / shape validation fails.

    Notes
    -----
    The merge is a pure index copy along dim 1 and applies no
    positional encoding itself; the caller must keep both branches
    in the same RoPE state (both pre-RoPE or both post-RoPE).

    The outputs are built by cloning the local tensors and then
    overwriting the evicted slots with the captured values, so the
    local values at evicted positions do not appear in the result.
    """
    # Rank check first so K_local.size(1) is meaningfully the T dim.
    if K_local.dim() != 4:
        raise ValueError(
            "K_local must be 4-D [B, T, num_kv_heads, head_dim]; got "
            f"shape {tuple(K_local.shape)}"
        )

    sorted_positions = _validate_positions(evicted_positions, K_local.size(1))

    if not sorted_positions:
        # Identity case: the captured tensors are allowed to be empty
        # or absent here, so the full shape validation is skipped.
        # The local pair must still be self-consistent, exactly as
        # the non-empty path requires.
        if K_local.shape != V_local.shape:
            raise ValueError(
                f"K_local shape {tuple(K_local.shape)} != V_local shape "
                f"{tuple(V_local.shape)}"
            )
        # Clone so callers can mutate the result freely.
        return K_local.clone(), V_local.clone()

    _validate_shapes(
        K_local, V_local, K_captured, V_captured, len(sorted_positions),
    )

    idx = torch.tensor(sorted_positions, device=K_local.device, dtype=torch.long)

    # index_copy_ is in-place, so operate on clones to leave the
    # caller's local tensors untouched.
    K_merged = K_local.clone()
    V_merged = V_local.clone()
    K_merged.index_copy_(dim=1, index=idx, source=K_captured)
    V_merged.index_copy_(dim=1, index=idx, source=V_captured)
    return K_merged, V_merged
| 245 | + |
| 246 | + |
def compute_evicted_positions(
    seq_len: int,
    sink_size: int,
    window_size: int,
) -> List[int]:
    """List the positions lying between the sink and the trailing window.

    A sequence of ``seq_len`` positions is split into three ranges:

    * ``[0, sink_size)`` — sink (kept)
    * ``[sink_size, seq_len - window_size)`` — middle (evicted)
    * ``[seq_len - window_size, seq_len)`` — window (kept)

    Only the middle range is returned.

    Parameters
    ----------
    seq_len
        Total number of positions.
    sink_size
        Number of positions kept at the head of the sequence.
    window_size
        Number of positions kept at the tail of the sequence.

    Returns
    -------
    The ascending, contiguous list
    ``[sink_size, ..., seq_len - window_size - 1]``, or ``[]`` when
    ``seq_len <= sink_size + window_size`` (sink and window together
    cover the whole sequence).

    Raises
    ------
    ValueError
        If any of ``seq_len``, ``sink_size``, ``window_size`` is
        negative.
    """
    if any(size < 0 for size in (seq_len, sink_size, window_size)):
        raise ValueError(
            f"seq_len={seq_len}, sink_size={sink_size}, "
            f"window_size={window_size} must all be non-negative"
        )

    first_evicted = sink_size
    window_start = seq_len - window_size
    # window_start <= first_evicted means sink and window touch or
    # overlap, leaving no middle range.
    if window_start <= first_evicted:
        return []
    return list(range(first_evicted, window_start))
0 commit comments