Commit 1831b5e
PR-R1e: write-path expansion (alpha FFN / beta multi-layer / gamma full block)
R1d-beta produced the decisive write-path-bottleneck diagnosis after
the needle-matcher bug was fixed: doubling localization (0.25 -> 0.50,
aux_loss 1.62 -> 0.69) had ZERO effect on cross_attn_recall (0.10 vs
0.12 within 50-sample noise). Per ADR 0011 §10's decision matrix this
is the (I-1) cell -- bridge can locate but cannot project. R1e tests
three architecturally distinct write-path expansions in a single PR,
all reusing R1d-beta's supervised aux loss (we know it lifts
localization 'for free' and was decisive in R1d).
Three variants, all opt-in via flags, all with zero step-0
contribution (so adding write capacity is a strict superset of the
R1d regime; runs without the new flags reproduce R1d step-for-step):
* R1e-alpha (--bridge-use-ffn-write-path):
Bridge.__init__ adds an LN + 4x SiLU FFN with zero-init down_proj.
forward() now does:
out = o_proj(cross_attn(h, bank))
if FFN: out = out + down_proj(silu(up_proj(LN(out))))
Tests 'is more write capacity (4x hidden width + nonlinearity)
enough?'.
* R1e-beta (--cross-attn-depths '8,14,20'):
CrossAttentionVerifier now accepts bridges={depth: bridge} dict
in addition to the legacy single-bridge signature. Each depth
registers an independent forward hook on layers[depth-1]; hooks
fire in increasing-depth order and each adds its own delta to
the residual stream. Capture_attention now populates a
_last_attention_weights_by_depth dict in addition to the
single-bridge alias (the deepest bridge for back-compat).
Tests 'is having only one chance to write the bottleneck?'.
* R1e-gamma (--bridge-use-block-architecture):
Replaces the bridge with a full pre-norm transformer block:
h_norm1 = input_norm(h)
attn_delta = o_proj(cross_attn(h_norm1, bank))
h_after_attn = h + attn_delta
ffn_delta = down_proj(silu(up_proj(LN(h_after_attn))))
delta = attn_delta + ffn_delta
Auto-implies --bridge-use-ffn-write-path. Tests 'is what's
needed a complete transformer block (norm + nonlinearity +
width), not just capacity?'.
Other changes:
* CLI: --cross-attn-depths (CSV, multi-bridge), --bridge-use-ffn-
write-path, --bridge-use-block-architecture, --ffn-expansion (4).
* Report schema_version 4 -> 5: cross_attn_depths is now a list;
bridge_use_ffn_write_path / bridge_use_block_architecture /
ffn_expansion / n_trainable_params config fields. Consumers keyed
on v4 must handle the list form (defaults to a single-element list
when only --cross-attn-depth was used) plus the three new keys.
* CrossAttentionVerifier signature change: 'bridges' kwarg added.
Specifying both 'bridges' and (cross_attn + cross_attn_depth)
raises ValueError (ADR 0008 §6.2 no silent fallback). Single-
bridge legacy signature preserves bit-identical behavior; all 65
R1d tests still pass without modification.
Tests (Linux CI, no GPU, no HF download, 83 cases all <0.25 s total):
* TestBridgeFFNWritePath (4 cases): FFN modules only when flag set;
zero-init down_proj makes step-0 output identical to no-FFN bridge
(R1d invariant preserved); breaking zero-init makes outputs
diverge; up/down_proj/ffn_norm all receive gradient.
* TestBridgeBlockArchitecture (4 cases): block flag implies FFN flag
(auto-promotion); strict-zero-init step-0 gives exact-zero block
delta; attn weights returned correctly; input_norm + attn_post_norm
receive gradient.
* TestMultiBridgeVerifier (8 cases): legacy single-bridge signature
unchanged; bridges dict creates correct structure; specifying both
APIs raises; specifying neither raises; invalid depth raises;
multi-bridge runs without crashing; all hooks removed after
forward including on exception; per-depth attention captured
correctly when capture_attention=True.
Reviewer script (scripts/review_pr_r1e_on_vast.sh) launches the
three variants. SEQUENTIAL by default because R1d's first attempt
hit CUBLAS_STATUS_ALLOC_FAILED running 3 capture-attention forwards
in parallel on a single H200; R1e-beta's 3 simultaneous attention
captures + retained activations would be even worse. PARALLEL=1 is
available for users on >80GB GPUs. Per-variant skip flags
(SKIP_ALPHA / SKIP_BETA / SKIP_GAMMA) for cost control.
Out of scope:
* No production engine integration (research-track).
* No change to v0.3 GA tag, runtime, or SDKs.
* No update to ADR 0011 status — that's gated on R1e empirical
outcome and will land in a follow-up PR.
Co-authored-by: FluffyAIcode <FluffyAIcode@users.noreply.github.com>1 parent 2757f0b commit 1831b5e
3 files changed
Lines changed: 839 additions & 89 deletions
File tree
- scripts
- research
- tests/research
0 commit comments