
Commit b2acb39

telgamal-1 authored and facebook-github-bot committed
Add CoreML-stable RMSNorm for llama eager paths (pytorch#19523)
Summary:
The standard `RMSNorm` formulation `x * rsqrt(mean(x²)) * weight` is numerically unstable on CoreML/ANE because the explicit FP32 cast around the mean reduction is silently stripped from the lowered graph, leaving the squared sum to overflow in FP16. The ANE PTE then diverges from the eager reference even on checkpoints fine-tuned in BF16/FP16.

This diff introduces `RMSNormCoreML` in `examples/models/llama/norm.py`. The module expresses the normalization as `x * sqrt(d) / vector_norm(x, dim=-1)`; `torch.linalg.vector_norm` keeps the reduction in a single op that survives CoreML lowering, so FP16 inference remains stable. To avoid `0 / 0 = NaN` on zero-padded positions (chunked prefill in `StaticAttentionIOManager` pads each chunk to `input_len` with zeros), the denominator is floored with `sqrt(dim * eps)`. This matches standard RMSNorm's `rsqrt(mean(x²) + eps)` semantics on a zero input and is large enough to survive FP16 (a plain `1e-6` underflows). Real (non-zero) tokens satisfy `vector_norm(x) >> sqrt(dim * eps)`, so the floor is a no-op on real positions.

A new `use_coreml_norm: bool = False` field on `ModelArgs` opts into the new norm without disturbing existing models. When True, every llama-side norm site constructs `RMSNormCoreML`:

- `llama_transformer.py`: `attention_norm`, `ffn_norm`, and the final `self.norm` on `Transformer`.
- `attention.py`: `q_norm_fn` / `k_norm_fn` in the affine QK-norm path, and the `else` branch of `_init_norms` (the scaleless / non-affine QK-norm path that the original landing missed).
- `static_attention.py`: `q_norm` / `k_norm` in the scaleless path, propagated through `from_attention_mha` by detecting `rms_norm_class is RMSNormCoreML`.

The QNN/HTP export path is untouched and continues to use `torch.nn.RMSNorm`.

Differential Revision: D104862210
1 parent 7cd209d commit b2acb39

5 files changed

Lines changed: 105 additions & 22 deletions
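
The formulation described in the summary can be sanity-checked in eager PyTorch. The snippet below is a standalone sketch, not part of the diff; the helper names `rms_standard` and `rms_coreml_style` are illustrative only.

import torch

def rms_standard(x, eps=1e-6):
    # Reference formulation: x * rsqrt(mean(x^2) + eps)
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def rms_coreml_style(x, eps=1e-6):
    # vector_norm formulation with the floored denominator:
    # x * sqrt(d) / max(||x||_2, sqrt(d * eps))
    d = x.shape[-1]
    denom = torch.clamp_min(
        torch.linalg.vector_norm(x, dim=-1, keepdim=True), (d * eps) ** 0.5
    )
    return x * (d ** 0.5) / denom

x = torch.randn(2, 8, 64)
# On real (non-zero) tokens the floor is a no-op and the two agree up to eps.
assert torch.allclose(rms_standard(x), rms_coreml_style(x), atol=1e-4)

# On a zero-padded position both formulations return zeros instead of NaN.
pad = torch.zeros(1, 1, 64)
assert rms_coreml_style(pad).abs().max() == 0
assert rms_standard(pad).abs().max() == 0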


examples/models/llama/attention.py

Lines changed: 18 additions & 9 deletions
@@ -9,6 +9,7 @@
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import (
     RMSNorm,
+    RMSNormCoreML,
     RMSNormGated,
     ScalelessRMSNorm,
 )
@@ -425,21 +426,29 @@ def _init_norms(self, args: ModelArgs) -> None:
         """Initialize QK normalization layers."""
         if self.use_qk_norm:
             if args.qk_norm_affine:
-                self.q_norm_fn = RMSNorm(
-                    self.head_dim,
-                    eps=args.norm_eps,
-                    add_unit_offset=args.rms_norm_add_unit_offset,
-                )
-                if self.has_kv_weights:
-                    self.k_norm_fn = RMSNorm(
+                if args.use_coreml_norm:
+                    self.q_norm_fn = RMSNormCoreML(self.head_dim, eps=args.norm_eps)
+                    if self.has_kv_weights:
+                        self.k_norm_fn = RMSNormCoreML(
+                            self.head_dim, eps=args.norm_eps
+                        )
+                else:
+                    self.q_norm_fn = RMSNorm(
                         self.head_dim,
                         eps=args.norm_eps,
                         add_unit_offset=args.rms_norm_add_unit_offset,
                     )
+                    if self.has_kv_weights:
+                        self.k_norm_fn = RMSNorm(
+                            self.head_dim,
+                            eps=args.norm_eps,
+                            add_unit_offset=args.rms_norm_add_unit_offset,
+                        )
             else:
-                self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+                cls = RMSNormCoreML if args.use_coreml_norm else ScalelessRMSNorm
+                self.q_norm_fn = cls(self.head_dim, eps=args.norm_eps)
                 if self.has_kv_weights:
-                    self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+                    self.k_norm_fn = cls(self.head_dim, eps=args.norm_eps)
         if self.use_attn_o_norm:
             self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
         if self.use_attn_o_gate:

examples/models/llama/llama_transformer.py

Lines changed: 19 additions & 10 deletions
@@ -23,6 +23,7 @@
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import (
     RMSNorm,
+    RMSNormCoreML,
     RMSNormWithInputScale,
     ScalelessRMSNorm,
 )
@@ -168,18 +169,23 @@ def __init__(

         if isinstance(self.attention, AttentionSkip):
             self.attention_norm = nn.Identity()
+        elif args.use_coreml_norm:
+            self.attention_norm = RMSNormCoreML(args.dim, eps=args.norm_eps)
         else:
             self.attention_norm = RMSNorm(
                 args.dim,
                 eps=args.norm_eps,
                 add_unit_offset=args.rms_norm_add_unit_offset,
             )
         if self.mlp_type != "skip":
-            self.ffn_norm = RMSNorm(
-                args.dim,
-                eps=args.norm_eps,
-                add_unit_offset=args.rms_norm_add_unit_offset,
-            )
+            if args.use_coreml_norm:
+                self.ffn_norm = RMSNormCoreML(args.dim, eps=args.norm_eps)
+            else:
+                self.ffn_norm = RMSNorm(
+                    args.dim,
+                    eps=args.norm_eps,
+                    add_unit_offset=args.rms_norm_add_unit_offset,
+                )

         if args.use_residual_gate:
             attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
@@ -273,11 +279,14 @@ def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
         )
         self.layers = layers
         self.rope = rope
-        self.norm = RMSNorm(
-            params.dim,
-            eps=params.norm_eps,
-            add_unit_offset=params.rms_norm_add_unit_offset,
-        )
+        if params.use_coreml_norm:
+            self.norm = RMSNormCoreML(params.dim, eps=params.norm_eps)
+        else:
+            self.norm = RMSNorm(
+                params.dim,
+                eps=params.norm_eps,
+                add_unit_offset=params.rms_norm_add_unit_offset,
+            )
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)
             if self.apply_output
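
Purely as an illustration (not part of this diff), the selection repeated at each norm site above amounts to a small factory; `make_norm` is a hypothetical name, and the `RMSNorm` keyword arguments are taken from the hunks:

from executorch.examples.models.llama.norm import RMSNorm, RMSNormCoreML

def make_norm(args, dim):
    # args is assumed to expose use_coreml_norm, norm_eps and
    # rms_norm_add_unit_offset, as ModelArgs does in this diff.
    if args.use_coreml_norm:
        return RMSNormCoreML(dim, eps=args.norm_eps)
    return RMSNorm(
        dim,
        eps=args.norm_eps,
        add_unit_offset=args.rms_norm_add_unit_offset,
    )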

examples/models/llama/model_args.py

Lines changed: 4 additions & 0 deletions
@@ -76,6 +76,10 @@ class ModelArgs:
         False  # Use q-gated projection in attention (Qwen3.5 full attention)
     )
     norm_type: str = "rmsnorm"  # Normalization type, registered in norm.py
+    # When True, swap RMSNorm for the CoreML-friendly RMSNormCoreML at every
+    # norm site. The CoreML formulation uses torch.linalg.vector_norm so the
+    # op is preserved in the CoreML graph (FP32 casts get stripped by CoreML).
+    use_coreml_norm: bool = False
     act_fn: ActFn = dataclasses.field(default=ActFn.SILU)  # Activation function type
     attention_qkv_bias: bool = False
     use_kv_cache: bool = False  # Use key/value cache
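
A hypothetical opt-in sketch (not in the diff) showing how a config would enable the new norm; it assumes the remaining `ModelArgs` fields keep their defaults, as in the `from_attention_mha` construction further down:

from executorch.examples.models.llama.model_args import ModelArgs

# Illustrative only: a config that routes every llama-side norm site
# through RMSNormCoreML instead of RMSNorm / ScalelessRMSNorm.
args = ModelArgs(
    dim=2048,
    norm_eps=1e-6,
    use_coreml_norm=True,
)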

examples/models/llama/norm.py

Lines changed: 44 additions & 0 deletions
@@ -57,6 +57,50 @@ def __init__(self, dim: int, eps: float = 1e-6):
         self.weight.requires_grad = False


+class RMSNormCoreML(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        CoreML-friendly RMSNorm — uses `torch.linalg.vector_norm` so the op is
+        preserved in the CoreML graph for numerical stability.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): Floor on the L2-norm denominator
+                (`clamp_min(‖x‖₂, √(dim·eps))`). Prevents `0/0 = NaN` on
+                zero-padded positions and matches standard RMSNorm's
+                `rsqrt(mean(x²) + eps)` semantics on a zero input. Must be > 0.
+
+        Attributes:
+            eps (float): Floor coefficient consumed by `_norm`.
+            weight (nn.Parameter): Learnable scaling parameter.
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        # Floor the denominator to avoid 0 / 0 = NaN on zero-padded positions
+        # (chunked prefill in StaticAttentionIOManager pads each chunk to
+        # input_len with zeros). Use sqrt(dim * eps) so the floor matches
+        # standard RMSNorm's eps semantics (`rsqrt(mean(x²) + eps)`) and is
+        # large enough to survive fp16 (1e-6 alone underflows in fp16).
+        floor_val = torch.sqrt(torch.tensor(self.dim * self.eps, dtype=x.dtype))
+        norm_val = torch.clamp_min(
+            torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val
+        )
+        rms_norm_eps0 = (
+            x
+            * torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
+            * torch.reciprocal(norm_val)
+        )
+        return rms_norm_eps0
+
+    def forward(self, x):
+        output = self._norm(x)
+        return output * self.weight
+
+
 class RMSNormWithInputScale(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5):
         super().__init__()
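
A quick eager-mode smoke test of the module as landed above; the shapes are arbitrary, and the zero tensor stands in for the zero-padded chunk positions mentioned in the comments:

import torch
from executorch.examples.models.llama.norm import RMSNormCoreML

norm = RMSNormCoreML(64, eps=1e-6)
real = torch.randn(1, 4, 64)   # real tokens
pad = torch.zeros(1, 4, 64)    # zero-padded positions

out_real, out_pad = norm(real), norm(pad)
assert torch.isfinite(out_real).all()
# The sqrt(dim * eps) floor keeps padded rows at exactly zero, never NaN.
assert torch.isfinite(out_pad).all() and out_pad.abs().max() == 0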

examples/models/llama/static_attention.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616
from executorch.examples.models.llama.lora import LoRALinear
1717
from executorch.examples.models.llama.model_args import ModelArgs
18-
from executorch.examples.models.llama.norm import ScalelessRMSNorm
18+
from executorch.examples.models.llama.norm import RMSNormCoreML, ScalelessRMSNorm
1919
from executorch.examples.models.llama.rope import Rope
2020

2121

@@ -898,18 +898,26 @@ def _init_wo(self, config: ModelArgs) -> None:
898898

899899
def _init_qk_norms(self, config: ModelArgs, is_kv_shared_layer: bool) -> None:
900900
if self.use_qk_norm:
901+
# When use_coreml_norm is set, match the rlformers reference path
902+
# which constructs q_norm/k_norm via RMSNormCoreML (no fp32 cast,
903+
# no eps, vector_norm-based) instead of ScalelessRMSNorm.
904+
_scaleless_cls = (
905+
RMSNormCoreML
906+
if getattr(config, "use_coreml_norm", False)
907+
else ScalelessRMSNorm
908+
)
901909
if getattr(config, "qk_norm_affine", True):
902910
self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
903911
if is_kv_shared_layer:
904912
self.k_norm = nn.Identity()
905913
else:
906914
self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
907915
else:
908-
self.q_norm = ScalelessRMSNorm(self.head_dim, eps=config.norm_eps)
916+
self.q_norm = _scaleless_cls(self.head_dim, eps=config.norm_eps)
909917
if is_kv_shared_layer:
910918
self.k_norm = nn.Identity()
911919
else:
912-
self.k_norm = ScalelessRMSNorm(self.head_dim, eps=config.norm_eps)
920+
self.k_norm = _scaleless_cls(self.head_dim, eps=config.norm_eps)
913921
else:
914922
self.q_norm = torch.nn.Identity()
915923
self.k_norm = torch.nn.Identity()
@@ -949,6 +957,14 @@ def from_attention_mha(
949957
hasattr(other.q_norm_fn, "weight") if other.use_qk_norm else True
950958
)
951959

960+
# Propagate use_coreml_norm so _init_qk_norms picks RMSNormCoreML for
961+
# scaleless q/k norms (matches the rlformers reference path). Detect
962+
# via the rms_norm_class kwarg — `transform_attention_mha_to_static_attention`
963+
# forwards it through, and the static_transformer_export caller already
964+
# selects RMSNormCoreML when use_coreml_norm is set on the model args.
965+
from executorch.examples.models.llama.norm import RMSNormCoreML
966+
_use_coreml_norm = rms_norm_class is RMSNormCoreML
967+
952968
config = ModelArgs(
953969
dim=other.dim,
954970
n_layers=1, # Not used in attention layer
@@ -964,6 +980,7 @@ def from_attention_mha(
964980
norm_eps=other.q_norm_fn.eps if other.use_qk_norm else 1e-5,
965981
num_kv_shared_layers=getattr(other, "num_kv_shared_layers", 0),
966982
scale_query_by=getattr(other, "scale_query_by", 1.0),
983+
use_coreml_norm=_use_coreml_norm,
967984
)
968985

969986
instance = cls(
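
The propagation above hinges on a class-identity check of the `rms_norm_class` kwarg. A minimal sketch of that detection; the helper name `_wants_coreml_norm` is illustrative, not part of the diff:

from executorch.examples.models.llama.norm import RMSNormCoreML, ScalelessRMSNorm

def _wants_coreml_norm(rms_norm_class) -> bool:
    # `is` compares class objects directly, so only an exact RMSNormCoreML
    # request flips use_coreml_norm on the rebuilt ModelArgs.
    return rms_norm_class is RMSNormCoreML

assert _wants_coreml_norm(RMSNormCoreML)
assert not _wants_coreml_norm(ScalelessRMSNorm)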
