Add CoreML-stable RMSNorm for llama eager paths (#19523)

telgamal-1 wants to merge 1 commit.
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19523.

No failures, 6 pending as of commit fc26253 with merge base d8e4ffd. (This comment was automatically generated by Dr. CI.)
@telgamal-1 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D104862210.
    eps (float, optional): Stored for API compatibility; ignored in the math.
[...]
Attributes:
    eps (float): Stored for API compatibility; not consumed by `_norm`.
Can we assert eps is 0 rather than silently drop it?
Added an explicit assert that eps is set to 0.
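In code, the requested guard amounts to a one-line check in `__init__`; a minimal sketch (the exact assertion message is an assumption):

```python
# Hypothetical guard per the review thread: fail loudly instead of
# silently dropping an eps value that this formulation never consumes.
assert eps == 0, "RMSNormCoreML does not consume eps; pass eps=0 explicitly"
```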
Force-pushed cda18f8 to b2acb39.
Summary:

The standard `RMSNorm` formulation `x * rsqrt(mean(x²)) * weight` is numerically unstable on CoreML/ANE because the explicit FP32 cast around the mean reduction is silently stripped from the lowered graph, leaving the squared sum to overflow in FP16. The ANE PTE then diverges from the eager reference even on checkpoints fine-tuned in BF16/FP16.

This diff introduces `RMSNormCoreML` in `examples/models/llama/norm.py`. The module expresses the normalization as `x * sqrt(d) / vector_norm(x, dim=-1)`; `torch.linalg.vector_norm` keeps the reduction in a single op that survives CoreML lowering, so FP16 inference remains stable.

To avoid `0 / 0 = NaN` on zero-padded positions (chunked prefill in `StaticAttentionIOManager` pads each chunk to `input_len` with zeros), the denominator is floored with `sqrt(dim * eps)`. This matches standard RMSNorm's `rsqrt(mean(x²) + eps)` semantics on a zero input and is large enough to survive FP16; a plain `1e-6` underflows. Real (non-zero) tokens satisfy `vector_norm(x) >> sqrt(dim * eps)`, so the floor is a no-op on real positions.

A new `use_coreml_norm: bool = False` field on `ModelArgs` opts into the new norm without disturbing existing models. When True, every llama-side norm site constructs `RMSNormCoreML`:

- `llama_transformer.py`: `attention_norm`, `ffn_norm`, and the final `self.norm` on `Transformer`.
- `attention.py`: `q_norm_fn` / `k_norm_fn` in the affine QK-norm path, and the `else` branch of `_init_qk_norms` (the scaleless / non-affine QK-norm path that the original landing missed).
- `static_attention.py`: `q_norm` / `k_norm` in the scaleless path, propagated through `from_attention_mha` by detecting `rms_norm_class is RMSNormCoreML`.

The QNN/HTP export path is untouched and continues to use `torch.nn.RMSNorm`.

Differential Revision: D104862210
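For concreteness, a minimal sketch of the formulation described above, assuming a `(dim, eps)` constructor and a `clamp`-based floor (neither detail is confirmed by the text of the diff):

```python
import math

import torch


class RMSNormCoreML(torch.nn.Module):
    """RMSNorm variant that keeps the reduction in one op for CoreML lowering."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x * sqrt(d) / ||x||_2 == x * rsqrt(mean(x^2)): a single vector_norm
        # reduction survives CoreML lowering without overflowing in FP16.
        norm = torch.linalg.vector_norm(x, dim=-1, keepdim=True)
        # Floor the denominator so zero-padded positions yield 0, not NaN.
        norm = torch.clamp(norm, min=math.sqrt(self.dim * self.eps))
        return x * (math.sqrt(self.dim) / norm) * self.weight
```

On an all-zero row, `norm` clamps to `sqrt(dim * eps)` and the output is exactly zero, matching what `x * rsqrt(mean(x²) + eps)` produces on the same input.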
Force-pushed b2acb39 to b5af889.
self.weight.requires_grad = False
[...]
class RMSNormCoreML(torch.nn.Module):
How does this differ from the existing RMSNorm in examples/apple/coreml/llama/llama_transformer.py?
Can we consolidate? Putting it here is fine, but then import this version into examples/apple/coreml/llama/llama_transformer.py.
I imported the new version in examples/apple/coreml/llama/llama_transformer.py because it was tested not to produce NaN in QAT.
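The consolidation then reduces to a single import on the CoreML example side; a sketch, with the fully qualified module path assumed rather than taken from the diff:

```python
# Reuse the shared norm in examples/apple/coreml/llama/llama_transformer.py
# instead of keeping a local duplicate (import path is an assumption).
from executorch.examples.models.llama.norm import RMSNormCoreML
```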
    )
    if self.has_kv_weights:
        self.k_norm_fn = RMSNorm(
    if args.use_coreml_norm:
Does this have to be integrated so far down?
Could we not leave llama_transformer/static_attention as is, introduce the new norm in norm.py, and then do a module swap from RMSNorm -> RMSNormCoreML at export time?
Addressed; now using the `replace_rms_norm_for_coreml_` module-swap strategy.
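A minimal sketch of that export-time swap, assuming the helper name from the comment above and matching on `torch.nn.RMSNorm` (the llama example may define its own `RMSNorm` class to match on instead; the import path and eps fallback are also assumptions):

```python
import torch

# Assumed import path for the norm introduced in norm.py.
from executorch.examples.models.llama.norm import RMSNormCoreML


def replace_rms_norm_for_coreml_(module: torch.nn.Module) -> torch.nn.Module:
    """Recursively swap torch.nn.RMSNorm children for RMSNormCoreML, in place."""
    for name, child in module.named_children():
        if isinstance(child, torch.nn.RMSNorm):
            dim = child.normalized_shape[-1]
            eps = child.eps if child.eps is not None else 1e-6  # assumed fallback
            new_norm = RMSNormCoreML(dim, eps=eps)
            if child.weight is not None:
                new_norm.weight = child.weight  # keep the trained scale
            setattr(module, name, new_norm)
        else:
            replace_rms_norm_for_coreml_(child)
    return module
```

Running this on the eager model just before export leaves `llama_transformer.py` and `static_attention.py` untouched, which is what the review asked for.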
Force-pushed b5af889 to ae1926c.
Force-pushed ae1926c to fc26253.
LGTM! You need to run the lintrunner, though.
@telgamal-1 has imported this pull request. If you are a Meta employee, you can view this in D104862210.
This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.