Add CoreML-stable RMSNorm for llama eager paths (#19523)
Summary:
The standard `RMSNorm` formulation `x * rsqrt(mean(x²)) * weight` is numerically unstable on CoreML/ANE because the explicit FP32 cast around the mean reduction is silently stripped from the lowered graph, leaving the squared sum to overflow in FP16. The ANE PTE then diverges from the eager reference even on checkpoints fine-tuned in BF16/FP16.
This diff introduces `RMSNormCoreML` in `examples/models/llama/norm.py`. The module expresses the normalization as `x * sqrt(d) / vector_norm(x, dim=-1)` — `torch.linalg.vector_norm` keeps the reduction in a single op that survives CoreML lowering, so FP16 inference remains stable.
To avoid `0 / 0 = NaN` on zero-padded positions (chunked prefill in `StaticAttentionIOManager` pads each chunk to `input_len` with zeros), the denominator is floored at `sqrt(dim * eps)`. This matches standard RMSNorm's `rsqrt(mean(x²) + eps)` semantics on an all-zero input, and the floor is large enough to survive FP16, where a plain `1e-6` underflows. Real (non-zero) tokens satisfy `vector_norm(x) >> sqrt(dim * eps)`, so the floor is a no-op on real positions.
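A minimal sketch of the module under the assumptions above, using the `(dim, eps)` constructor signature and elementwise affine weight of the standard llama RMSNorm; the landed `norm.py` may differ in detail:

```python
import math

import torch


class RMSNormCoreML(torch.nn.Module):
    """Sketch of the CoreML/ANE-stable RMSNorm: x * sqrt(d) / ||x||,
    with the denominator floored at sqrt(dim * eps).
    See examples/models/llama/norm.py for the real module."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Single-op L2 reduction that survives CoreML lowering, unlike
        # x * rsqrt(mean(x**2)) whose FP32 cast gets stripped.
        norm = torch.linalg.vector_norm(x, dim=-1, keepdim=True)
        # Floor the denominator so zero-padded rows map to zero instead
        # of NaN; sqrt(dim * eps) matches rsqrt(mean(x**2) + eps) on a
        # zero input and does not underflow in FP16.
        norm = torch.clamp(norm, min=math.sqrt(self.dim * self.eps))
        return x * math.sqrt(self.dim) / norm * self.weight
```

On non-zero inputs the floor never triggers and the result is algebraically identical to `x * rsqrt(mean(x²)) * weight`, since `vector_norm(x, dim=-1) = sqrt(dim * mean(x²))`.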
A new `use_coreml_norm: bool = False` field on `ModelArgs` opts into the new norm without disturbing existing models. When `True`, every llama-side norm site constructs `RMSNormCoreML` (a sketch of the selection pattern follows the list):
- `llama_transformer.py`: `attention_norm`, `ffn_norm`, the final `self.norm` on `Transformer`.
- `attention.py`: `q_norm_fn` / `k_norm_fn` in the affine QK-norm path, and also the `else` branch of `_init_qk_norms` (the scaleless, non-affine QK-norm path that the original landing missed).
- `static_attention.py`: `q_norm` / `k_norm` in the scaleless path, propagated through `from_attention_mha` by detecting `rms_norm_class is RMSNormCoreML`.
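A hypothetical sketch of the selection pattern at those sites, assuming `RMSNormCoreML` from the sketch above; the real `ModelArgs` carries many more fields, and the helper name is illustrative:

```python
from dataclasses import dataclass

import torch


@dataclass
class ModelArgs:
    dim: int = 4096
    norm_eps: float = 1e-5
    use_coreml_norm: bool = False  # new field; default keeps existing behavior


def make_norm(args: ModelArgs) -> torch.nn.Module:
    # Hypothetical helper mirroring what each norm site does: pick the
    # CoreML-stable class when the flag is set, torch.nn.RMSNorm otherwise.
    # Both classes accept a (dim, eps=...) construction.
    cls = RMSNormCoreML if args.use_coreml_norm else torch.nn.RMSNorm
    return cls(args.dim, eps=args.norm_eps)


# from_attention_mha can then propagate the choice by class identity,
# i.e. checking rms_norm_class is RMSNormCoreML.
```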
The QNN/HTP export path is untouched and continues to use `torch.nn.RMSNorm`.
Differential Revision: D104862210