Skip to content

Commit 651305b

Browse files
PR-K1.C: per-attention-layer K/V preparation primitive
Third foundational PR of the K-series implementing ADR 0008 §11.7 phase K1. Builds on K1.A (capture, PR #70 merged) and K1.B (merge, PR #71 to be merged). Stacked on top of #71. This PR implements the 'attention-internal' primitive: given a Gemma3- style attention layer's locally-computed K and V tensors at all positions (post-norm post-RoPE for K, raw for V) plus a K/V capture from the dLM proposer at evicted positions (pre-norm pre-RoPE per K1.A's contract), produce the merged K/V tensors that the verifier attention should consume as if the verifier had run full attention. This is the load-bearing piece between K1.A's capture and K1.D's end-to-end verifier integration. K1.D will monkey-patch Gemma3Attention.forward to call this primitive right before attention_interface(...). Files: inference_engine/v04/restored_attention.py (288 lines) * _rotate_half(x) — standard RoPE half-rotation (split last dim, swap halves with sign flip on second). Local implementation rather than HF dependency so Linux unit tests are completely offline. Math is bit-identical to HF transformers' rotate_half; cross-checked at integration in K1.D. * apply_rope_to_k_at_positions(k, cos, sin) — apply RoPE to K-only. HF's apply_rotary_pos_emb rotates Q and K together; we only rotate K because v0.4 K/V Restoration injects K/V at evicted positions while Q is the verifier's standard query. Validates rank/shape; raises on mismatch (no silent fallback per ADR 0008 §6.2). * slice_position_embeddings(cos, sin, positions) — index_select wrapper with the same position-list contract as kv_merge (sorted, deduped, in range). * prepare_restored_attention_kv(...) — main entry point. Three steps: 1. Apply k_norm to captured K (V skips norm). 2. Apply RoPE to captured K with cos/sin sliced at evicted positions (transpose layout for the rope call, transpose back for merge). 3. Merge into K_local / V_local at evicted positions via K1.B's merge_kv_at_evicted_positions. Final layout is [B, num_kv_heads, T, head_dim] — the post-norm post-RoPE shape attention_interface consumes. Empty evicted list returns K_local / V_local clones (identity case). Validates shape consistency. ADR 0008 §6.2: no silent fallback. Why apply RoPE here and not at K1.A capture time: * RoPE is verifier-step-position-dependent. K1.A runs as part of the proposer's forward, where the verifier's per-step position_embeddings are not yet known. * In K2 / K3 the cross-model projection f_theta will sit between captured K (proposer space) and merged K (verifier space). f_theta runs cleanly in the pre-RoPE space because RoPE is verifier-specific. * In same-model identity case (K1), the captured K passed through verifier's k_norm + apply_rope is bit-exact what the verifier-with-full-attention would produce at that position. ADR 0008 §11.5 §'Five properties' item 2 ('intelligence approximates full attention') becomes a constructive equality in this case. Cross-checked in K1.D Mac M4 reviewer. inference_engine/v04/__init__.py Public API re-export. After this PR public surface covers K1.A (capture), K1.B (merge), K1.C (per-layer K/V prep); K1.D will add a DLMRestoredVerifier wrapper class on top. tests/inference_engine/v04/test_restored_attention.py (444 lines, 32 cases, all <0.10 s on Linux CI) Test classes: * TestRotateHalf — half-rotation primitive; double-rotate-negates invariant; rank/dtype preservation. * TestApplyRopeToKAtPositions — RoPE math: identity when cos=1/sin=0; pure rotation when cos=0/sin=1; hand-rolled scalar reference for T=1 head_dim=4; multi-head broadcasting; rank/ shape/compatibility raises. * TestSlicePositionEmbeddings — index_select wrapper with position-list contract (sorted, deduped, in range, non-empty); rank/shape raises. * TestPrepareRestoredAttentionKV (12 cases) — end-to-end on synthetic k_norm + manual cos/sin: empty evicted is identity; shape preservation; non-evicted positions preserve K_local bit-exactly; evicted K equals normed-then-roped captured (ref against manual apply_rope on transposed captured); evicted V equals captured V (no norm, no RoPE); k_norm scaling is actually applied (2x scale gives 2x evicted-position output); shape validation raises (K_local rank 3, V_local mismatch, captured rank 3); gradient flows through captured K/V; bf16 dtype preservation; consecutive evicted block (the common case from compute_evicted_positions). Combined v04 test status after this PR: tests/inference_engine/v04/ has 103 cases (32 K1.A + 39 K1.B + 32 K1.C), all <0.15 s on Linux CI, no HF model download. What's next: K1.D — DLMRestoredVerifier wrapper that ties capture (K1.A) + merge (K1.B) + per-layer prep (this PR) together via a monkey-patch on Gemma3Attention.forward, then runs end- to-end NIAH validation on Mac M4 against: * full-attention oracle baseline (target ~100% recall) * v0.3 sink+window=4+64 (measured 16.7% in sink_window_quality_ab_1780714635.json) * v0.4 verifier with K/V Restoration enabled (ADR 0008 §11.8 gate (a) target: >= 95% mid-context recall at 100k context). Stacking notes: Base branch (logical) is #71 K1.B; tooling reasons make us set base_branch=main. After #71 is merged into main, this PR's diff shrinks to just the K1.C additions. Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
1 parent 6e80430 commit 651305b

3 files changed

Lines changed: 913 additions & 0 deletions

File tree

inference_engine/v04/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@
3737
compute_evicted_positions,
3838
merge_kv_at_evicted_positions,
3939
)
40+
from inference_engine.v04.restored_attention import (
41+
apply_rope_to_k_at_positions,
42+
prepare_restored_attention_kv,
43+
slice_position_embeddings,
44+
)
4045

4146
__all__ = [
4247
# K1.A — capture
@@ -46,4 +51,8 @@
4651
# K1.B — merge
4752
"compute_evicted_positions",
4853
"merge_kv_at_evicted_positions",
54+
# K1.C — restored attention K/V preparation
55+
"apply_rope_to_k_at_positions",
56+
"prepare_restored_attention_kv",
57+
"slice_position_embeddings",
4958
]

0 commit comments

Comments
 (0)