Commit 651305b
PR-K1.C: per-attention-layer K/V preparation primitive
Third foundational PR of the K-series implementing ADR 0008 §11.7
phase K1. Builds on K1.A (capture, PR #70 merged) and K1.B (merge,
PR #71 to be merged). Stacked on top of #71.
This PR implements the 'attention-internal' primitive: given a Gemma3-
style attention layer's locally-computed K and V tensors at all
positions (post-norm post-RoPE for K, raw for V) plus a K/V capture
from the dLM proposer at evicted positions (pre-norm pre-RoPE per
K1.A's contract), produce the merged K/V tensors that the verifier
attention should consume as if the verifier had run full attention.
This is the load-bearing piece between K1.A's capture and K1.D's
end-to-end verifier integration. K1.D will monkey-patch
Gemma3Attention.forward to call this primitive right before
attention_interface(...).
Files:
inference_engine/v04/restored_attention.py (288 lines)
* _rotate_half(x) — standard RoPE half-rotation
(split last dim, swap halves with sign flip on second). Local
implementation rather than HF dependency so Linux unit tests
are completely offline. Math is bit-identical to HF transformers'
rotate_half; to be cross-checked at integration in K1.D.
* apply_rope_to_k_at_positions(k, cos, sin) — apply RoPE to K-only.
HF's apply_rotary_pos_emb rotates Q and K together; we only
rotate K because v0.4 K/V Restoration injects K/V at evicted
positions while Q is the verifier's standard query. Validates
rank/shape; raises on mismatch (no silent fallback per
ADR 0008 §6.2).
* slice_position_embeddings(cos, sin, positions) — index_select
wrapper with the same position-list contract as kv_merge
(sorted, deduped, in range).
* prepare_restored_attention_kv(...) — main entry point. Three
steps:
1. Apply k_norm to captured K (V skips norm).
2. Apply RoPE to captured K with cos/sin sliced at evicted
positions (transpose layout for the rope call, transpose
back for merge).
3. Merge into K_local / V_local at evicted positions via
K1.B's merge_kv_at_evicted_positions. Final layout is
[B, num_kv_heads, T, head_dim] — the post-norm post-RoPE
shape attention_interface consumes.
Empty evicted list returns K_local / V_local clones (identity
case). Validates shape consistency. ADR 0008 §6.2: no silent
fallback.
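The three steps above can be sketched in plain Python, with nested lists standing in for tensors. This is an illustration under stated assumptions, not the code in restored_attention.py: it handles a single batch element and a single KV head, toy_k_norm is a hypothetical weight-free RMS norm standing in for the layer's k_norm, and the names prepare_restored_kv and check_positions are invented for this sketch.

```python
import math

def rotate_half(x):
    """Split the vector in half and return [-second_half, first_half]."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])

def apply_rope_to_k(k, cos, sin):
    """RoPE on K only: k * cos + rotate_half(k) * sin, element-wise."""
    if not (len(k) == len(cos) == len(sin)):
        raise ValueError("k, cos and sin must share head_dim")
    rot = rotate_half(k)
    return [a * c + r * s for a, r, c, s in zip(k, rot, cos, sin)]

def toy_k_norm(x, eps=1e-6):
    """Hypothetical stand-in for k_norm: RMS norm without a learned weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def check_positions(positions, seq_len):
    """Position-list contract: sorted, deduped, in range. Raises otherwise."""
    if any(b <= a for a, b in zip(positions, positions[1:])):
        raise ValueError("positions must be strictly increasing")
    if any(p < 0 or p >= seq_len for p in positions):
        raise ValueError("position out of range")

def prepare_restored_kv(k_local, v_local, k_cap, v_cap, cos, sin, evicted):
    """k_local, v_local, cos, sin: [T][head_dim] (one batch, one KV head).
    k_cap, v_cap: [len(evicted)][head_dim], pre-norm and pre-RoPE."""
    if len(k_local) != len(v_local):
        raise ValueError("K_local and V_local disagree on sequence length")
    if len(k_cap) != len(evicted) or len(v_cap) != len(evicted):
        raise ValueError("capture length must match the evicted list")
    check_positions(evicted, len(k_local))
    # Copies, so an empty evicted list yields clones (the identity case).
    k_out = [list(row) for row in k_local]
    v_out = [list(row) for row in v_local]
    for i, pos in enumerate(evicted):
        normed = toy_k_norm(k_cap[i])                             # step 1
        k_out[pos] = apply_rope_to_k(normed, cos[pos], sin[pos])  # steps 2 and 3
        v_out[pos] = list(v_cap[i])                               # V: no norm, no RoPE
    return k_out, v_out

T, D = 4, 4
k_local = [[float(t)] * D for t in range(T)]
v_local = [[10.0 + t] * D for t in range(T)]
cos = [[1.0] * D for _ in range(T)]
sin = [[0.0] * D for _ in range(T)]
k_out, v_out = prepare_restored_kv(
    k_local, v_local, [[3.0, 4.0, 0.0, 0.0]], [[7.0] * D], cos, sin, [2]
)
```

Only position 2 is overwritten here; every other row of k_out and v_out is a copy of the local tensors. Invalid input raises rather than falling back, in the spirit of ADR 0008 §6.2.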
Why apply RoPE here and not at K1.A capture time:
* RoPE is verifier-step-position-dependent. K1.A runs as part
of the proposer's forward, where the verifier's per-step
position_embeddings are not yet known.
* In K2 / K3 the cross-model projection f_theta will sit between
captured K (proposer space) and merged K (verifier space).
f_theta runs cleanly in the pre-RoPE space because RoPE is
verifier-specific.
* In the same-model identity case (K1), the captured K passed
through the verifier's k_norm + apply_rope is bit-exactly what the
verifier-with-full-attention would produce at that position.
ADR 0008 §11.5 §'Five properties' item 2 ('intelligence
approximates full attention') becomes a constructive equality
in this case. To be cross-checked in K1.D on the Mac M4 reviewer.
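A small pure-Python sketch of the position-dependence argument above, under assumptions that are not taken from this PR: the frequency base of 10000, the weight-free toy_k_norm, and the example positions are all illustrative. Both paths run the same operations on the same raw K, which is exactly why the same-model case is an equality; applying RoPE at a different position than the verifier's gives a different K.

```python
import math

def rope_tables(position, head_dim, base=10000.0):
    """cos/sin for one position, with angles duplicated across both halves.
    The base value is an assumption for illustration."""
    half = head_dim // 2
    angles = [position * base ** (-2.0 * i / head_dim) for i in range(half)]
    angles = angles + angles
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]

def rotate_half(x):
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])

def apply_rope(k, cos, sin):
    rot = rotate_half(k)
    return [a * c + r * s for a, r, c, s in zip(k, rot, cos, sin)]

def toy_k_norm(x, eps=1e-6):
    """Hypothetical stand-in for the verifier's k_norm."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

raw_k = [0.5, -1.0, 2.0, 0.25]   # pre-norm, pre-RoPE K at one evicted position
verifier_pos = 7                  # position the verifier assigns at this step

# Full-attention path: the verifier norms and rotates its own K.
full = apply_rope(toy_k_norm(raw_k), *rope_tables(verifier_pos, 4))

# Restoration path: K captured raw, then verifier norm + verifier cos/sin.
captured = list(raw_k)
restored = apply_rope(toy_k_norm(captured), *rope_tables(verifier_pos, 4))

# RoPE baked in at capture time with a different position would not match.
stale = apply_rope(toy_k_norm(raw_k), *rope_tables(3, 4))

print(restored == full, stale == full)
```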
inference_engine/v04/__init__.py
Public API re-export. After this PR public surface covers
K1.A (capture), K1.B (merge), K1.C (per-layer K/V prep);
K1.D will add a DLMRestoredVerifier wrapper class on top.
tests/inference_engine/v04/test_restored_attention.py (444 lines,
32 cases, all <0.10 s on Linux CI)
Test classes:
* TestRotateHalf — half-rotation primitive; double-rotate-negates
invariant; rank/dtype preservation.
* TestApplyRopeToKAtPositions — RoPE math: identity when
cos=1/sin=0; pure rotation when cos=0/sin=1; hand-rolled scalar
reference for T=1 head_dim=4; multi-head broadcasting; rank/
shape/compatibility raises.
* TestSlicePositionEmbeddings — index_select wrapper with
position-list contract (sorted, deduped, in range, non-empty);
rank/shape raises.
* TestPrepareRestoredAttentionKV (12 cases) — end-to-end on
synthetic k_norm + manual cos/sin: empty evicted is identity;
shape preservation; non-evicted positions preserve K_local
bit-exactly; evicted K equals normed-then-roped captured (ref
against manual apply_rope on transposed captured); evicted V
equals captured V (no norm, no RoPE); k_norm scaling is actually
applied (2x scale gives 2x evicted-position output); shape
validation raises (K_local rank 3, V_local mismatch, captured
rank 3); gradient flows through captured K/V; bf16 dtype
preservation; consecutive evicted block (the common case from
compute_evicted_positions).
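A few of the RoPE invariants listed above can be written as plain assertions. The sketch below uses lists instead of tensors and hypothetical test names; it mirrors the described cases (double-rotate negates, identity at cos=1/sin=0, pure rotation at cos=0/sin=1, hand-rolled reference for head_dim=4) but is not the content of test_restored_attention.py.

```python
def rotate_half(x):
    """Split the vector in half and return [-second_half, first_half]."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])

def apply_rope_to_k(k, cos, sin):
    """k * cos + rotate_half(k) * sin, element-wise."""
    rot = rotate_half(k)
    return [a * c + r * s for a, r, c, s in zip(k, rot, cos, sin)]

K = [1.0, 2.0, 3.0, 4.0]

def test_double_rotate_negates():
    # rotate_half([x1, x2]) = [-x2, x1]; applying it twice gives [-x1, -x2].
    assert rotate_half(rotate_half(K)) == [-v for v in K]

def test_identity_when_cos_one_sin_zero():
    assert apply_rope_to_k(K, [1.0] * 4, [0.0] * 4) == K

def test_pure_rotation_when_cos_zero_sin_one():
    assert apply_rope_to_k(K, [0.0] * 4, [1.0] * 4) == rotate_half(K)

def test_hand_rolled_reference():
    cos = [0.5, 0.25, 0.5, 0.25]
    sin = [2.0, 4.0, 2.0, 4.0]
    # rotate_half(K) = [-3, -4, 1, 2], so each entry is k*cos + rot*sin.
    expected = [
        1.0 * 0.5 + (-3.0) * 2.0,
        2.0 * 0.25 + (-4.0) * 4.0,
        3.0 * 0.5 + 1.0 * 2.0,
        4.0 * 0.25 + 2.0 * 4.0,
    ]
    assert apply_rope_to_k(K, cos, sin) == expected

for case in (
    test_double_rotate_negates,
    test_identity_when_cos_one_sin_zero,
    test_pure_rotation_when_cos_zero_sin_one,
    test_hand_rolled_reference,
):
    case()
```

The values are chosen so every product and sum is exact in binary floating point, which lets the checks use equality rather than a tolerance.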
Combined v04 test status after this PR:
tests/inference_engine/v04/ has 103 cases (32 K1.A + 39 K1.B +
32 K1.C), all <0.15 s on Linux CI, no HF model download.
What's next:
K1.D — DLMRestoredVerifier wrapper that ties capture (K1.A) +
merge (K1.B) + per-layer prep (this PR) together via a
monkey-patch on Gemma3Attention.forward, then runs end-
to-end NIAH validation on Mac M4 against:
* full-attention oracle baseline (target ~100% recall)
* v0.3 sink+window=4+64 (measured 16.7% in
sink_window_quality_ab_1780714635.json)
* v0.4 verifier with K/V Restoration enabled (ADR 0008
§11.8 gate (a) target: >= 95% mid-context recall at
100k context).
Stacking notes:
The logical base branch is #71 (K1.B); for tooling reasons
base_branch is set to main. After #71 is merged into main, this PR's diff
shrinks to just the K1.C additions.
Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
1 parent 6e80430 commit 651305b
3 files changed
Lines changed: 913 additions & 0 deletions