Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity (#19096)
Summary:
The 730M dense model checkpoint uses several features that the ExecuTorch XNNPACK export path did not implement. Without these, the exported model produces numerically incorrect output.
This diff adds support for 8 missing features:
1. `normalize_tok_embeddings` — scaleless RMSNorm after embedding lookup
2. `qk_norm_before_rope` — wired up conversion from GenAI args (the attention code already supported this)
3. `scale_query_by` — custom scalar multiplier on Q after QK norm
4. `use_attn_o_gate` — sigmoid gate on attention output using a learned linear projection of the layer input
5. `use_attn_o_norm` — scaleless per-head RMSNorm on attention output (applied before o_gate)
6. `use_residual_gate` — NormPreservingResidualConnection with learned per-dim gates for both attention and FFN residual connections
7. `use_ffn_learnable_scales` — RMSNormWithInputScale replacing standard post-FFN norm, computing `rms_norm(gamma * x)` instead of `gamma * rms_norm(x)`
8. `output_soft_cap_temp` — `tanh(logits/temp) * temp` soft capping on output logits
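Three of the features above are simple elementwise transforms. The sketch below shows them on plain Python lists to make the math concrete; the real implementation operates on torch tensors inside the ET model, and the function names here are illustrative, not the diff's actual API.

```python
import math

def rms_norm(x, weight=None, eps=1e-6):
    # Scaleless RMSNorm when weight is None (normalize_tok_embeddings,
    # use_attn_o_norm); affine RMSNorm when a learned weight is supplied.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    y = [v / rms for v in x]
    return y if weight is None else [w * v for w, v in zip(weight, y)]

def rms_norm_with_input_scale(x, gamma, eps=1e-6):
    # use_ffn_learnable_scales: rms_norm(gamma * x) -- the learned scale is
    # applied *before* normalization, unlike the standard gamma * rms_norm(x).
    return rms_norm([g * v for g, v in zip(gamma, x)], eps=eps)

def soft_cap(logits, temp):
    # output_soft_cap_temp: tanh(logits / temp) * temp squashes every logit
    # into the interval [-temp, temp].
    return [math.tanh(v / temp) * temp for v in logits]
```

Note that `rms_norm_with_input_scale` is not equivalent to post-scaling: for a non-uniform `gamma`, scaling before the norm changes the relative magnitudes that the RMS is computed over, so the two orderings produce different outputs.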
Additionally, this diff fixes a QK norm checkpoint compatibility issue: some checkpoints contain learned QK norm weights even though their `params.json` has `qk_norm_affine=False` (due to default changes after training). The ET model was creating `ScalelessRMSNorm` (no weight parameter) based on `params.json`, silently discarding the checkpoint's trained QK norm weights. The rlformers reference model loaded them correctly, causing ~53-67 dB SNR divergence. The fix peeks at the checkpoint state dict before model construction — if QK norm weights are present, `qk_norm_affine` is overridden to `True` so the ET model creates affine QK norms that load those weights.
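The checkpoint-peeking fix can be sketched as a standalone helper. This is a hypothetical reconstruction: the key suffixes (`q_norm.weight` / `k_norm.weight`) and the helper name are assumed for illustration and are not taken from the actual diff.

```python
def maybe_force_affine_qk_norm(params, state_dict):
    # Peek at the checkpoint before model construction. If it carries
    # trained QK norm weights, override a stale qk_norm_affine=False from
    # params.json so the model builds affine QK norms that load them,
    # instead of ScalelessRMSNorm silently discarding the weights.
    # Key suffixes here are assumptions, not the real checkpoint layout.
    has_qk_weights = any(
        k.endswith("q_norm.weight") or k.endswith("k_norm.weight")
        for k in state_dict
    )
    if has_qk_weights and not params.get("qk_norm_affine", False):
        params = dict(params, qk_norm_affine=True)
    return params
```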
All features are off by default (backward compatible). They activate when the corresponding fields are set in the checkpoint's `params.json` and propagated through `model_args_conversion`.
Weight key mappings added for: `attention.og.weight`, `add_attn.gate`, `add_ffn.gate`, `post_ffn_norm.weight`.
Reviewed By: chinnadhurai, digantdesai
Differential Revision: D102030169