Skip to content

Commit 1831b5e

Browse files
PR-R1e: write-path expansion (alpha FFN / beta multi-layer / gamma full block)
R1d-beta produced the decisive write-path-bottleneck diagnosis after the needle-matcher bug was fixed: doubling localization (0.25 -> 0.50, aux_loss 1.62 -> 0.69) had ZERO effect on cross_attn_recall (0.10 vs 0.12 within 50-sample noise). Per ADR 0011 §10's decision matrix this is the (I-1) cell -- bridge can locate but cannot project. R1e tests three architecturally distinct write-path expansions in a single PR, all reusing R1d-beta's supervised aux loss (we know it lifts localization 'for free' and was decisive in R1d). Three variants, all opt-in via flags, all with zero step-0 contribution (so adding write capacity is a strict superset of the R1d regime; runs without the new flags reproduce R1d step-for-step): * R1e-alpha (--bridge-use-ffn-write-path): Bridge.__init__ adds an LN + 4x SiLU FFN with zero-init down_proj. forward() now does: out = o_proj(cross_attn(h, bank)) if FFN: out = out + down_proj(silu(up_proj(LN(out)))) Tests 'is more write capacity (4x hidden width + nonlinearity) enough?'. * R1e-beta (--cross-attn-depths '8,14,20'): CrossAttentionVerifier now accepts bridges={depth: bridge} dict in addition to the legacy single-bridge signature. Each depth registers an independent forward hook on layers[depth-1]; hooks fire in increasing-depth order and each adds its own delta to the residual stream. Capture_attention now populates a _last_attention_weights_by_depth dict in addition to the single-bridge alias (the deepest bridge for back-compat). Tests 'is having only one chance to write the bottleneck?'. * R1e-gamma (--bridge-use-block-architecture): Replaces the bridge with a full pre-norm transformer block: h_norm1 = input_norm(h) attn_delta = o_proj(cross_attn(h_norm1, bank)) h_after_attn = h + attn_delta ffn_delta = down_proj(silu(up_proj(LN(h_after_attn)))) delta = attn_delta + ffn_delta Auto-implies --bridge-use-ffn-write-path. Tests 'is what's needed a complete transformer block (norm + nonlinearity + width), not just capacity?'. Other changes: * CLI: --cross-attn-depths (CSV, multi-bridge), --bridge-use-ffn- write-path, --bridge-use-block-architecture, --ffn-expansion (4). * Report schema_version 4 -> 5: cross_attn_depths is now a list; bridge_use_ffn_write_path / bridge_use_block_architecture / ffn_expansion / n_trainable_params config fields. Consumers keyed on v4 must handle the list form (defaults to a single-element list when only --cross-attn-depth was used) plus the three new keys. * CrossAttentionVerifier signature change: 'bridges' kwarg added. Specifying both 'bridges' and (cross_attn + cross_attn_depth) raises ValueError (ADR 0008 §6.2 no silent fallback). Single- bridge legacy signature preserves bit-identical behavior; all 65 R1d tests still pass without modification. Tests (Linux CI, no GPU, no HF download, 83 cases all <0.25 s total): * TestBridgeFFNWritePath (4 cases): FFN modules only when flag set; zero-init down_proj makes step-0 output identical to no-FFN bridge (R1d invariant preserved); breaking zero-init makes outputs diverge; up/down_proj/ffn_norm all receive gradient. * TestBridgeBlockArchitecture (4 cases): block flag implies FFN flag (auto-promotion); strict-zero-init step-0 gives exact-zero block delta; attn weights returned correctly; input_norm + attn_post_norm receive gradient. * TestMultiBridgeVerifier (8 cases): legacy single-bridge signature unchanged; bridges dict creates correct structure; specifying both APIs raises; specifying neither raises; invalid depth raises; multi-bridge runs without crashing; all hooks removed after forward including on exception; per-depth attention captured correctly when capture_attention=True. Reviewer script (scripts/review_pr_r1e_on_vast.sh) launches the three variants. SEQUENTIAL by default because R1d's first attempt hit CUBLAS_STATUS_ALLOC_FAILED running 3 capture-attention forwards in parallel on a single H200; R1e-beta's 3 simultaneous attention captures + retained activations would be even worse. PARALLEL=1 is available for users on >80GB GPUs. Per-variant skip flags (SKIP_ALPHA / SKIP_BETA / SKIP_GAMMA) for cost control. Out of scope: * No production engine integration (research-track). * No change to v0.3 GA tag, runtime, or SDKs. * No update to ADR 0011 status — that's gated on R1e empirical outcome and will land in a follow-up PR. Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>
1 parent 2757f0b commit 1831b5e

3 files changed

Lines changed: 839 additions & 89 deletions

File tree

0 commit comments

Comments
 (0)