Commit ab9915d

PR-K1.A: K/V capture infrastructure for v0.4 dLM K/V Restoration

ADR 0008 §11.5 v0.4 GA architecture states the dLM proposer's parallel forward computes K, V at every position transiently and discards them afterwards. v0.4 K/V Restoration uses these as a constant-memory reconstruction source for the verifier's evicted-position cache slots (at compute time only; they are never permanently stored).

This is the first foundational PR (K1.A) of the K-series implementing ADR 0008 §11.7 phase K1 (same-model toy on Gemma 3-1B). Subsequent PRs in the series:

* K1.B: K/V injection. Take a captured proposer KV at evicted positions and feed them into a verifier forward as the K/V the verifier attends to outside its sink+window.
* K1.C: DLMRestoredVerifier wrapper. Orchestrates capture + injection + sink+window cache management for end-to-end inference.
* K1.D: NIAH validation harness on Mac M4 against the v0.3 sink+window baseline and the full-attention oracle.

This PR (K1.A) implements only the capture half.

Files:

inference_engine/v04/__init__.py

Public API surface; thin re-export of capture entry points. The production-style end-to-end class (DLMRestoredVerifier) lands in K1.C.

inference_engine/v04/kv_capture.py (497 lines)

* KVCapture dataclass with shape/dtype/device invariants enforced at construction. Stores per-layer [B, T, num_kv_heads, head_dim] tensors detached from the proposer's autograd graph.
* KVCapture.select_positions(positions): slicing for the K1.B injection step (extract evicted-position K/V).
* register_kv_capture_hooks(model, layer_indices=None): primitive that hooks the k_proj / v_proj Linear modules of every (or selected) decoder layer's self-attention sub-module. Returns (k_acc, v_acc, handles); caller is responsible for handle.remove.
* _locate_attention_layers: supports HF Gemma3 / Llama / Qwen / Mistral (.model.layers[*].self_attn) and GPT-2 family (.transformer.h[*].attn). Raises RuntimeError on unrecognised shape (no silent fallback per ADR 0008 §6.2).
* capture_proposer_kv: high-level entry point. Runs a single forward, harvests K/V, returns KVCapture. Decorated @torch.no_grad() because the proposer is frozen by ADR 0008 §11.5 design; differentiable path through the capture is available via register_kv_capture_hooks directly.

Capture point rationale (documented in module docstring):

K and V are captured pre-norm and pre-RoPE, i.e., the raw k_proj / v_proj outputs. The injection layer (K1.B) is responsible for re-applying RoPE for the verifier's target query position. This is chosen over post-RoPE capture because (a) it gives a stable hook point on a clean nn.Linear module that doesn't depend on Gemma3Attention.forward internals, (b) the K2 cross-model projection f_θ runs more naturally in the pre-RoPE space, and (c) same-model identity round-trip is bit-exact under attn_implementation='eager'.

tests/inference_engine/v04/test_kv_capture.py (440 lines, 32 cases, all <0.10 s on Linux CI)

Test classes:

* TestLocateAttentionLayers: Gemma/Llama, GPT-2, unrecognised, zero-layer raise paths.
* TestRegisterKVCaptureHooks: hook installation/removal lifecycle, layer_indices subset selection, dedup/sort, out-of-range raises, missing k_proj/v_proj raises.
* TestCaptureProposerKVShapes: shape (1, T, num_kv_heads, head_dim); detachment from autograd; dtype consistency under fp64.
* TestCaptureProposerKVValues: value correctness. Captured layer-0 K and V match a manual k_proj(embed_tokens(input_ids)) reference bit-exactly; capture is deterministic under fixed seed.
* TestCaptureProposerKVConfigInference: explicit num_kv_heads / head_dim overrides take precedence; fallback derivation when config has no .head_dim; raises on inconsistent shape rather than silently reshaping.
* TestKVCaptureSelectPositions: index correctness against index_select; raises on unsorted, duplicated, empty, negative, and beyond-seqlen position lists.
* TestKVCaptureInvariants: dataclass __post_init__ rejects empty, length-mismatched, shape-inconsistent, dtype-mismatched, and rank-3 inputs.
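
The layout probing attributed to _locate_attention_layers above can be sketched in a few lines of plain Python. This is a rough illustration, not the code in kv_capture.py: the function name `locate_attention_layers`, the error messages, and the `SimpleNamespace` stand-ins are assumptions made for the example. Only the two attribute paths and the raise-rather-than-fall-back behaviour come from the commit message.

```python
from types import SimpleNamespace
from typing import Any, List


def locate_attention_layers(model: Any) -> List[Any]:
    """Return the self-attention sub-module of every decoder layer.

    Hypothetical stand-in for _locate_attention_layers; it probes only the
    two layouts named in the commit message and raises otherwise.
    """
    # HF Gemma3 / Llama / Qwen / Mistral layout: model.model.layers[i].self_attn
    layers = getattr(getattr(model, "model", None), "layers", None)
    attn_name = "self_attn"
    if layers is None:
        # GPT-2 family layout: model.transformer.h[i].attn
        layers = getattr(getattr(model, "transformer", None), "h", None)
        attn_name = "attn"
    if layers is None:
        raise RuntimeError("unrecognised model layout: no decoder layers found")
    layers = list(layers)
    if not layers:
        raise RuntimeError("model has zero decoder layers")
    found = []
    for i, layer in enumerate(layers):
        attn = getattr(layer, attn_name, None)
        if attn is None:
            raise RuntimeError(f"layer {i} has no .{attn_name} sub-module")
        found.append(attn)
    return found


# Duck-typed stand-ins for the two supported layouts (strings play the
# role of attention modules so the example needs no ML dependencies).
gemma_like = SimpleNamespace(
    model=SimpleNamespace(
        layers=[SimpleNamespace(self_attn="attn0"), SimpleNamespace(self_attn="attn1")]
    )
)
gpt2_like = SimpleNamespace(
    transformer=SimpleNamespace(h=[SimpleNamespace(attn="gpt2_attn0")])
)
gemma_found = locate_attention_layers(gemma_like)
gpt2_found = locate_attention_layers(gpt2_like)
```

Raising on an unknown layout, instead of returning an empty list, keeps a mis-detected model from silently producing a capture with no layers.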

Tests use synthetic mini-models (_MiniGemmaShapeModel, _MiniGPT2ShapeModel) that mirror the HF hook surface without HF transformers download, so Linux CI runs in <0.1s and the value-correctness assertions are bit-exact.

Out of scope (deliberate):

* Real Gemma 3-1B integration smoke: lives in K1.D Mac M4 reviewer.
* RoPE re-application: K1.B (injection step).
* Verifier KV cache merging: K1.B + K1.C.
* NIAH validation: K1.D.
* Cross-model f_θ projection: K2 (separate ADR 0008 §11.7 phase).

Stacked on top of PR #69 (ADR 0008 v0.4 amendment); merge order: #69 first (architecture document), then this PR.

Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
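
As an illustration of the capture technique the message describes (forward hooks on the k_proj / v_proj Linear modules, checked against a manual layer-0 reference), here is a minimal PyTorch sketch. It is not the code from kv_capture.py: `MiniSelfAttention`, `MiniModel`, and `capture_kv` are hypothetical names, the toy keeps its attention modules directly in `model.layers` and does not reproduce the HF attribute paths, and the mixing step in `forward` is a placeholder for real attention.

```python
import torch
from torch import nn


class MiniSelfAttention(nn.Module):
    """Toy module exposing only the k_proj / v_proj hook surface."""

    def __init__(self, hidden: int, num_kv_heads: int, head_dim: int):
        super().__init__()
        kv_dim = num_kv_heads * head_dim
        self.k_proj = nn.Linear(hidden, kv_dim, bias=False)
        self.v_proj = nn.Linear(hidden, kv_dim, bias=False)
        self.o_proj = nn.Linear(kv_dim, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder mixing; real attention is irrelevant to the hooks.
        return x + self.o_proj(self.k_proj(x) + self.v_proj(x))


class MiniModel(nn.Module):
    def __init__(self, vocab=32, hidden=16, num_layers=2, num_kv_heads=2, head_dim=4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleList(
            [MiniSelfAttention(hidden, num_kv_heads, head_dim) for _ in range(num_layers)]
        )
        self.num_kv_heads, self.head_dim = num_kv_heads, head_dim

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        x = self.embed_tokens(input_ids)
        for attn in self.layers:
            x = attn(x)
        return x


def capture_kv(model: MiniModel, input_ids: torch.Tensor):
    """Run one forward and return per-layer K and V as [B, T, heads, dim]."""
    k_acc, v_acc, handles = [], [], []

    def make_hook(acc):
        def hook(module, inputs, output):
            b, t, _ = output.shape
            # Detach so the capture holds no reference to the autograd graph.
            acc.append(output.detach().reshape(b, t, model.num_kv_heads, model.head_dim))
        return hook

    for attn in model.layers:
        handles.append(attn.k_proj.register_forward_hook(make_hook(k_acc)))
        handles.append(attn.v_proj.register_forward_hook(make_hook(v_acc)))
    try:
        with torch.no_grad():
            model(input_ids)
    finally:
        for handle in handles:
            handle.remove()  # always uninstall, even if the forward raises
    return k_acc, v_acc


torch.manual_seed(0)
model = MiniModel()
input_ids = torch.randint(0, 32, (1, 5))
k_layers, v_layers = capture_kv(model, input_ids)
# Manual layer-0 reference: raw k_proj output on the embeddings.
reference = model.layers[0].k_proj(model.embed_tokens(input_ids)).detach().reshape(1, 5, 2, 4)
```

Because the hook sits on the nn.Linear output, what it records is the raw projection before any normalisation or rotary embedding, which matches the pre-norm, pre-RoPE capture point described in the rationale.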
1 parent 8c8a67d commit ab9915d

4 files changed

Lines changed: 1181 additions & 0 deletions

inference_engine/v04/__init__.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+"""Kakeya Inference Engine v0.4 architecture.
+
+This subpackage implements the v0.4 GA design as specified in
+ADR 0008 §11 (the v0.4 amendment dated 2026-06-08): the verifier
+maintains a minimal sink+window KV cache, and at every generation
+step accepts transient K/V tensors at evicted positions reconstructed
+from the dLM proposer's parallel forward pass.
+
+The architecture's load-bearing fact, recorded in ADR 0008 §11.3:
+the dLM proposer has no KV cache, so its K/V tensors at every
+position are computed transiently each forward and discarded. This
+makes the proposer a constant-memory K/V reconstruction source.
+
+Implementation phases per ADR 0008 §11.7:
+
+* **K1**: same-model toy (proposer and verifier share Gemma 3-1B
+  weights). Implement K/V routing infrastructure. Validate on
+  synthetic NIAH that recall ≈ oracle when the projection is
+  identity.
+* **K2**: cross-model toy (proposer = Gemma 3-1B, verifier = Gemma
+  3-4B). Train per-layer linear projection f_θ.
+* **K3**: production scale.
+* **K4**: KakeyaLattice composition.
+* **K5**: default flip + docs.
+
+This `__init__.py` is intentionally a thin re-export layer. The
+production-style API (a `DLMRestoredVerifier` class wrapping the
+whole pipeline) lands in K1.C; K1.A / K1.B build the foundation.
+"""
+
+from inference_engine.v04.kv_capture import (
+    KVCapture,
+    capture_proposer_kv,
+    register_kv_capture_hooks,
+)
+
+__all__ = [
+    "KVCapture",
+    "capture_proposer_kv",
+    "register_kv_capture_hooks",
+]
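
The module docstring speaks of a sink+window KV cache and of K/V restored "at evicted positions". To make that concrete, the sketch below computes which positions such a cache would no longer hold. It is an assumption-laden illustration: the helper `evicted_positions` and the parameters `num_sink` and `window` are hypothetical and do not appear in the commit, and the eviction rule (keep the first `num_sink` and the last `window` positions) is the usual sink+window convention rather than something this diff confirms.

```python
from typing import List


def evicted_positions(seq_len: int, num_sink: int, window: int) -> List[int]:
    """Positions a sink+window cache no longer holds at length seq_len.

    Assumed rule: the first num_sink positions and the last window
    positions are kept; everything in between is evicted.
    """
    if seq_len < 0 or num_sink < 0 or window < 0:
        raise ValueError("seq_len, num_sink and window must be non-negative")
    first_evicted = min(num_sink, seq_len)
    # Clamp so sink and window may overlap on short sequences.
    first_in_window = max(seq_len - window, first_evicted)
    return list(range(first_evicted, first_in_window))


long_context = evicted_positions(seq_len=10, num_sink=2, window=3)
short_context = evicted_positions(seq_len=4, num_sink=2, window=3)
```

The result is sorted, duplicate-free, and within the sequence length, which are the properties the commit message says `KVCapture.select_positions` checks. The list is empty while the sequence still fits in sink plus window, and since `select_positions` is described as raising on an empty list, a caller would presumably skip restoration in that case.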
