
Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity (#19096)

Merged
meta-codesync[bot] merged 1 commit into pytorch:main from ifed-ucsd:export-D102030169 on Apr 30, 2026

Conversation

Contributor

@ifed-ucsd ifed-ucsd commented Apr 23, 2026

Summary:

The 730M dense model checkpoint uses several rlformers features that the ExecuTorch XNNPACK export path did not implement. Without these, the exported model produces numerically incorrect output.

This diff adds support for 8 missing features:

  1. `normalize_tok_embeddings` — scaleless RMSNorm after embedding lookup
  2. `qk_norm_before_rope` — conversion from GenAI args (the attention code already supported it)
  3. `scale_query_by` — custom scalar multiplier on Q after QK norm
  4. `use_attn_o_gate` — sigmoid gate on the attention output, computed from a learned linear projection of the layer input
  5. `use_attn_o_norm` — scaleless per-head RMSNorm on the attention output (applied before the o_gate)
  6. `use_residual_gate` — NormPreservingResidualConnection with learned per-dim gates on both the attention and FFN residual connections
  7. `use_ffn_learnable_scales` — RMSNormWithInputScale replacing the standard post-FFN norm, computing `rms_norm(gamma * x)` instead of `gamma * rms_norm(x)`
  8. `output_soft_cap_temp` — `tanh(logits / temp) * temp` soft capping on the output logits
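
For reference, here is a minimal, illustrative sketch of a few of these forward-pass ops. The function names and signatures are hypothetical stand-ins rather than the actual ExecuTorch modules; they only mirror the math described in the list above.

```python
import torch
import torch.nn.functional as F


def scaleless_rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm with no learned scale, as used for normalize_tok_embeddings and
    # (applied per head) for use_attn_o_norm.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


def rms_norm_with_input_scale(x: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
    # use_ffn_learnable_scales: rms_norm(gamma * x) instead of gamma * rms_norm(x).
    return scaleless_rms_norm(gamma * x)


def attn_o_gate(layer_input: torch.Tensor, attn_out: torch.Tensor,
                og_weight: torch.Tensor) -> torch.Tensor:
    # use_attn_o_gate: sigmoid gate from a learned linear projection of the
    # layer input, applied elementwise to the attention output.
    return attn_out * torch.sigmoid(F.linear(layer_input, og_weight))


def gated_residual(x: torch.Tensor, sublayer_out: torch.Tensor,
                   gate: torch.Tensor) -> torch.Tensor:
    # use_residual_gate: learned per-dim gate on the residual branch
    # (a simplified stand-in for NormPreservingResidualConnection).
    return x + gate * sublayer_out


def soft_cap(logits: torch.Tensor, temp: float) -> torch.Tensor:
    # output_soft_cap_temp: tanh(logits / temp) * temp on the output logits.
    return torch.tanh(logits / temp) * temp
```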

Additionally, this diff fixes a QK norm checkpoint compatibility issue: some checkpoints contain learned QK norm weights even though their `params.json` has `qk_norm_affine=False` (due to default changes after training). The ET model was creating `ScalelessRMSNorm` (no weight parameter) based on `params.json`, silently discarding the checkpoint's trained QK norm weights. The rlformers reference model loaded them correctly, causing ~53–67 dB SNR divergence between the two. The fix peeks at the checkpoint state dict before model construction — if QK norm weights are present, `qk_norm_affine` is overridden to `True` so the ET model creates affine QK norms that load those weights.
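
A minimal sketch of that override, assuming the checkpoint file is a plain state dict and assuming hypothetical key suffixes (`q_norm.weight` / `k_norm.weight`); the real key layout and arg names may differ:

```python
import torch


def maybe_force_qk_norm_affine(checkpoint_path: str, model_args) -> None:
    # Load the state dict only to inspect its keys before model construction.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    has_qk_norm_weights = any(
        key.endswith(("q_norm.weight", "k_norm.weight")) for key in state_dict
    )
    # If trained QK norm weights exist, override the params.json setting so the
    # model is built with affine QK norms and actually loads those weights.
    if has_qk_norm_weights and not model_args.qk_norm_affine:
        model_args.qk_norm_affine = True
```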

All features are off by default (backward compatible). They activate when the corresponding fields are set in the checkpoint's `params.json` and propagated through `model_args_conversion`.
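
As an illustration of that propagation, the sketch below uses a hypothetical `BackboneArgs` dataclass and conversion helper; only the field names and their off-by-default values are taken from the list above.

```python
import json
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class BackboneArgs:
    # All new fields default to off/None, so existing exports are unchanged.
    normalize_tok_embeddings: bool = False
    qk_norm_before_rope: bool = False
    scale_query_by: Optional[float] = None
    use_attn_o_gate: bool = False
    use_attn_o_norm: bool = False
    use_residual_gate: bool = False
    use_ffn_learnable_scales: bool = False
    output_soft_cap_temp: Optional[float] = None


def args_from_params_json(path: str) -> BackboneArgs:
    # Pick up only the recognized fields from params.json; ignore the rest.
    with open(path) as f:
        params = json.load(f)
    known = {f.name for f in fields(BackboneArgs)}
    return BackboneArgs(**{k: v for k, v in params.items() if k in known})
```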

Weight key mappings added for: `attention.og.weight`, `add_attn.gate`, `add_ffn.gate`, `post_ffn_norm.weight`.
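
The left-hand key names in the sketch below are the ones listed above; the right-hand ExecuTorch-side targets are placeholders, since the actual destination names are not shown in this summary.

```python
# Hypothetical suffix remapping from rlformers checkpoint keys to ET model keys.
WEIGHT_KEY_MAP = {
    "attention.og.weight": "attention.o_gate.weight",  # placeholder target name
    "add_attn.gate": "attn_residual_gate",             # placeholder target name
    "add_ffn.gate": "ffn_residual_gate",               # placeholder target name
    "post_ffn_norm.weight": "post_ffn_norm.weight",    # may already match
}


def remap_weight_keys(state_dict: dict) -> dict:
    # Rewrite the suffix of each matching key, leaving layer prefixes intact.
    remapped = {}
    for key, value in state_dict.items():
        for src, dst in WEIGHT_KEY_MAP.items():
            if key.endswith(src):
                key = key[: -len(src)] + dst
                break
        remapped[key] = value
    return remapped
```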

Reviewed By: chinnadhurai, digantdesai

Differential Revision: D102030169

@ifed-ucsd ifed-ucsd requested a review from lucylq as a code owner April 23, 2026 21:53

pytorch-bot Bot commented Apr 23, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19096

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (2 Unrelated Failures)

As of commit 3d5f0d6 with merge base d9688da:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 23, 2026

meta-codesync Bot commented Apr 23, 2026

@ifed-ucsd has exported this pull request. If you are a Meta employee, you can view the originating Diff in D102030169.

@github-actions

This PR needs a `release notes:` label

If your change should be included in the release notes (i.e., would users of this library care about this change?), please use a label starting with `release notes:`. This helps us keep track of and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@meta-codesync meta-codesync Bot changed the title Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity (#19096) Apr 23, 2026
ifed-ucsd added a commit to ifed-ucsd/executorch that referenced this pull request Apr 23, 2026
…ice export parity (pytorch#19096)

ifed-ucsd added a commit to ifed-ucsd/executorch that referenced this pull request Apr 23, 2026
…ice export parity (pytorch#19096)

ifed-ucsd added a commit to ifed-ucsd/executorch that referenced this pull request Apr 23, 2026
…ice export parity (pytorch#19096)

ifed-ucsd added a commit to ifed-ucsd/executorch that referenced this pull request Apr 23, 2026
…ice export parity (pytorch#19096)

@digantdesai digantdesai left a comment


Review automatically exported from Phabricator review in Meta.

@meta-codesync meta-codesync Bot changed the title Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity (#19096) Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity Apr 28, 2026
@ifed-ucsd ifed-ucsd force-pushed the export-D102030169 branch 2 times, most recently from e686bed to 18e2b65 Compare April 28, 2026 16:46
@ifed-ucsd ifed-ucsd force-pushed the export-D102030169 branch 3 times, most recently from 5c8ace1 to 85c4f8c Compare April 29, 2026 18:59
…ice export parity (pytorch#19096)
@meta-codesync meta-codesync Bot changed the title Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity (#19096) Apr 29, 2026
@meta-codesync meta-codesync Bot merged commit 8a97ac7 into pytorch:main Apr 30, 2026
173 of 176 checks passed

Labels

CLA Signed (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed), fb-exported, meta-exported

Projects

None yet

Development


3 participants