Commit ab9915d
PR-K1.A: K/V capture infrastructure for v0.4 dLM K/V Restoration
ADR 0008 §11.5 (v0.4 GA architecture) states that the dLM proposer's
parallel forward computes K and V at every position transiently and
discards them afterwards. v0.4 K/V Restoration uses these as a
constant-memory reconstruction source for the verifier's
evicted-position cache slots (at compute time only; they are never
permanently stored).
This is the first foundational PR (K1.A) of the K-series implementing
ADR 0008 §11.7 phase K1 (same-model toy on Gemma 3-1B). Subsequent
PRs in the series:
K1.B: K/V injection. Takes the captured proposer K/V at evicted
positions and feeds it into a verifier forward as the K/V
the verifier attends to outside its sink+window.
K1.C: DLMRestoredVerifier wrapper. Orchestrates capture + injection
+ sink+window cache management for end-to-end inference.
K1.D: NIAH validation harness on Mac M4 against the v0.3 sink+window
baseline and the full-attention oracle.
This PR (K1.A) implements only the capture half. Files:
inference_engine/v04/__init__.py
Public API surface; thin re-export of capture entry points. The
production-style end-to-end class (DLMRestoredVerifier) lands in
K1.C.
inference_engine/v04/kv_capture.py (497 lines)
* KVCapture dataclass with shape/dtype/device invariants enforced
at construction. Stores per-layer [B, T, num_kv_heads, head_dim]
tensors detached from the proposer's autograd graph.
* KVCapture.select_positions(positions) — slicing for the K1.B
injection step (extract evicted-position K/V).
* register_kv_capture_hooks(model, layer_indices=None) — primitive
that hooks the k_proj / v_proj Linear modules of every (or
selected) decoder layer's self-attention sub-module. Returns
(k_acc, v_acc, handles); the caller is responsible for calling
handle.remove() on each handle.
* _locate_attention_layers — supports HF Gemma3 / Llama / Qwen /
Mistral (.model.layers[*].self_attn) and GPT-2 family
(.transformer.h[*].attn). Raises RuntimeError on unrecognised
shape (no silent fallback per ADR 0008 §6.2).
* capture_proposer_kv — high-level entry point. Runs a single
forward, harvests K/V, and returns a KVCapture. Decorated with
@torch.no_grad() because the proposer is frozen by design (ADR 0008
§11.5); a differentiable path through the capture remains available
by calling register_kv_capture_hooks directly.
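As a rough illustration of the hook primitive described above, the following sketch registers forward hooks on the k_proj / v_proj Linear modules of a few stand-in attention layers and reshapes the raw projection output to [B, T, num_kv_heads, head_dim]. It is not the code in kv_capture.py: TinyAttention and register_hooks_sketch are hypothetical names, and passing num_kv_heads / head_dim explicitly is an assumption made to keep the example short.

```python
import torch
from torch import nn


class TinyAttention(nn.Module):
    """Hypothetical stand-in exposing only the k_proj / v_proj hook surface."""

    def __init__(self, hidden: int, num_kv_heads: int, head_dim: int) -> None:
        super().__init__()
        self.k_proj = nn.Linear(hidden, num_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(hidden, num_kv_heads * head_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The return value is irrelevant here; the projections only need to run.
        return self.k_proj(x) + self.v_proj(x)


def register_hooks_sketch(attn_layers, num_kv_heads: int, head_dim: int):
    """Hook k_proj / v_proj on each layer; return (k_acc, v_acc, handles)."""
    k_acc, v_acc, handles = {}, {}, []

    def make_hook(acc, idx):
        def hook(module, inputs, output):
            # output is the raw projection: [B, T, num_kv_heads * head_dim].
            b, t, _ = output.shape
            acc[idx] = output.reshape(b, t, num_kv_heads, head_dim)
        return hook

    for idx, attn in enumerate(attn_layers):
        handles.append(attn.k_proj.register_forward_hook(make_hook(k_acc, idx)))
        handles.append(attn.v_proj.register_forward_hook(make_hook(v_acc, idx)))
    return k_acc, v_acc, handles


torch.manual_seed(0)
layers = [TinyAttention(hidden=16, num_kv_heads=2, head_dim=4) for _ in range(3)]
k_acc, v_acc, handles = register_hooks_sketch(layers, num_kv_heads=2, head_dim=4)
x = torch.randn(1, 5, 16)
try:
    # no_grad mirrors the frozen-proposer path; omit it for a differentiable capture.
    with torch.no_grad():
        for layer in layers:
            layer(x)
finally:
    # The caller owns the handles and must remove them, even on failure.
    for handle in handles:
        handle.remove()
```

Removing the handles in a finally block keeps a failed forward from leaving hooks attached to the model, which is the main hazard of handing handle ownership to the caller.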
Capture point rationale (documented in module docstring):
K and V are captured pre-norm and pre-RoPE — i.e., the raw
k_proj / v_proj outputs. The injection layer (K1.B) is responsible
for re-applying RoPE for the verifier's target query position. This
is chosen over post-RoPE capture because (a) it gives a stable hook
point on a clean nn.Linear module that doesn't depend on
Gemma3Attention.forward internals, (b) the K2 cross-model projection
f_θ runs more naturally in the pre-RoPE space, and (c) same-model
identity round-trip is bit-exact under attn_implementation='eager'.
tests/inference_engine/v04/test_kv_capture.py (440 lines, 32 cases,
all <0.10 s on Linux CI)
Test classes:
* TestLocateAttentionLayers — Gemma/Llama, GPT-2, unrecognised, zero-
layer raise paths.
* TestRegisterKVCaptureHooks — hook installation/removal lifecycle,
layer_indices subset selection, dedup/sort, out-of-range raises,
missing k_proj/v_proj raises.
* TestCaptureProposerKVShapes — shape (1, T, num_kv_heads, head_dim);
detachment from autograd; dtype consistency under fp64.
* TestCaptureProposerKVValues — value correctness: captured layer-0
K and V match a manual k_proj(embed_tokens(input_ids)) reference
bit-exactly; capture is deterministic under fixed seed.
* TestCaptureProposerKVConfigInference — explicit num_kv_heads /
head_dim overrides take precedence; fallback derivation when
config has no .head_dim; raises on inconsistent shape rather than
silently reshaping.
* TestKVCaptureSelectPositions — index correctness against
index_select; raises on unsorted, duplicated, empty, negative,
and beyond-seqlen position lists.
* TestKVCaptureInvariants — dataclass __post_init__ rejects empty,
length-mismatched, shape-inconsistent, dtype-mismatched, and
rank-3 inputs.
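The position checks exercised by TestKVCaptureSelectPositions can be pictured with a small validator. This is a minimal sketch under stated assumptions: validate_positions is a hypothetical helper, and raising ValueError is a guess, since the commit message does not name the exception type that select_positions actually raises.

```python
from typing import List, Sequence


def validate_positions(positions: Sequence[int], seq_len: int) -> List[int]:
    """Return positions as a list, or raise if they are unsafe to index with."""
    pos = list(positions)
    if not pos:
        raise ValueError("positions must be non-empty")
    if any(p < 0 for p in pos):
        raise ValueError("positions must be non-negative")
    if any(p >= seq_len for p in pos):
        raise ValueError(f"positions must be < seq_len ({seq_len})")
    # Strictly increasing covers both the unsorted and the duplicated case.
    if any(b <= a for a, b in zip(pos, pos[1:])):
        raise ValueError("positions must be sorted with no duplicates")
    return pos


# Validated positions can then index the T axis of each per-layer tensor,
# for example with torch.index_select(k, dim=1, index=...).
selected = validate_positions([0, 2, 5], seq_len=8)
```

Rejecting unsorted or duplicated lists up front, rather than sorting and de-duplicating silently, matches the no-silent-fallback stance the commit message cites elsewhere.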
Tests use synthetic mini-models (_MiniGemmaShapeModel,
_MiniGPT2ShapeModel) that mirror the HF hook surface without an HF
transformers download, so Linux CI runs in <0.1 s and the
value-correctness assertions are bit-exact.
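The idea of testing against objects that only mirror the hook surface can be sketched with a duck-typed layer locator. The function below is an illustrative guess at the behaviour described for _locate_attention_layers (two recognised layouts, RuntimeError otherwise); locate_attention_layers_sketch and the SimpleNamespace stand-ins are hypothetical and do not reproduce the real test fixtures.

```python
from types import SimpleNamespace


def locate_attention_layers_sketch(model):
    """Return the per-layer attention sub-modules, or raise RuntimeError."""
    inner = getattr(model, "model", None)
    transformer = getattr(model, "transformer", None)
    if inner is not None and hasattr(inner, "layers"):
        # Gemma3 / Llama / Qwen / Mistral style: .model.layers[*].self_attn
        blocks, attr = list(inner.layers), "self_attn"
    elif transformer is not None and hasattr(transformer, "h"):
        # GPT-2 style: .transformer.h[*].attn
        blocks, attr = list(transformer.h), "attn"
    else:
        raise RuntimeError("unrecognised model layout")
    if not blocks:
        raise RuntimeError("model has zero decoder layers")
    missing = [i for i, block in enumerate(blocks) if not hasattr(block, attr)]
    if missing:
        raise RuntimeError(f"layers {missing} have no .{attr} sub-module")
    return [getattr(block, attr) for block in blocks]


# Stand-ins that mimic only the attribute paths the locator walks.
gemma_like = SimpleNamespace(
    model=SimpleNamespace(
        layers=[SimpleNamespace(self_attn="attn0"), SimpleNamespace(self_attn="attn1")]
    )
)
gpt2_like = SimpleNamespace(
    transformer=SimpleNamespace(h=[SimpleNamespace(attn="attn0")])
)
found_gemma = locate_attention_layers_sketch(gemma_like)
found_gpt2 = locate_attention_layers_sketch(gpt2_like)
```

Because the locator only walks attribute paths, stand-ins this small are enough to cover the recognised, unrecognised, and zero-layer branches without loading any real weights.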
Out of scope (deliberate):
* Real Gemma 3-1B integration smoke — lives in K1.D Mac M4 reviewer.
* RoPE re-application — K1.B (injection step).
* Verifier KV cache merging — K1.B + K1.C.
* NIAH validation — K1.D.
* Cross-model f_θ projection — K2 (separate ADR 0008 §11.7 phase).
Stacked on top of PR #69 (ADR 0008 v0.4 amendment); merge order:
#69 first (architecture document), then this PR.
Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
1 parent 8c8a67d commit ab9915d
4 files changed
Lines changed: 1181 additions & 0 deletions
File tree
- inference_engine/v04
- tests/inference_engine/v04