61 changes: 1 addition & 60 deletions examples/apple/coreml/llama/llama_transformer.py
@@ -14,7 +14,7 @@

import torch
import torch.nn.functional as F
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.norm import RMSNorm, RMSNormCoreML

Check warning on line 17 in examples/apple/coreml/llama/llama_transformer.py (GitHub Actions / lintrunner): FLAKE8 F401: 'executorch.examples.models.llama.norm.RMSNormCoreML' imported but unused. See https://www.flake8rules.com/rules/F401.html

from executorch.examples.models.llama.rope import (
hf_apply_rotary_emb,
@@ -109,65 +109,6 @@
self.head_dim = self.dim // self.n_heads


class CoreMLRMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
Initialize the RMSNorm normalization layer.

Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.

"""
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The normalized tensor.

"""
# CoreML ignores casts to FP32, so existing implementation of RMSNorm was not stable
# We instead use (x * sqrt(n)) / norm(x, dim=-1)
# Using torch.norm and preserving this op in CoreML improves stability
# Note, we ignore eps, but could add it by using torch.norm(torch.concat(x, sqrt(n*eps))) in the denominator
# In future, we want to add CoreML support for the functional RMSNorm op
# We have yet to do large scale evaluations on the numeric stability of this solution, but note that
# it appears better than what exists currently (removing FP32 casts and using FP16)
rms_norm_eps0 = (
x
* torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
* torch.reciprocal(torch.linalg.vector_norm(x, dim=-1, keepdim=True))
)
return rms_norm_eps0

def forward(self, x):
"""
Forward pass through the RMSNorm layer.

Args:
x (torch.Tensor): The input tensor.

Returns:
torch.Tensor: The output tensor after applying RMSNorm.

"""
output = self._norm(x)
return output * self.weight


class Rope(torch.nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
78 changes: 78 additions & 0 deletions examples/models/llama/norm.py
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
@@ -57,6 +57,51 @@
self.weight.requires_grad = False


class RMSNormCoreML(torch.nn.Module):
Contributor

How does this differ from:

class CoreMLRMSNorm(torch.nn.Module):

Can we consolidate? Putting it here is fine, but then import this version into examples/apple/coreml/llama/llama_transformer.py.

Author

I imported the new version in examples/apple/coreml/llama/llama_transformer.py because it is the version that was tested not to produce NaN during QAT.
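
For context, a minimal sketch (not part of the PR) of the failure mode being discussed, assuming a zero-padded position such as the ones chunked prefill produces; the old CoreMLRMSNorm denominator collapses to zero, while the clamped denominator in RMSNormCoreML does not:

import torch

dim, eps = 8, 1e-6
x = torch.zeros(1, dim)  # zero-padded position, e.g. from chunked prefill

# Old CoreMLRMSNorm formulation: ||x|| = 0, so reciprocal(0) = inf and 0 * inf = NaN
old = x * torch.sqrt(torch.tensor(float(dim))) * torch.reciprocal(
    torch.linalg.vector_norm(x, dim=-1, keepdim=True)
)
print(old)  # tensor of NaNs

# RMSNormCoreML formulation: denominator floored at sqrt(dim * eps), output stays 0
floor_val = torch.sqrt(torch.tensor(dim * eps))
new = x * torch.sqrt(torch.tensor(float(dim))) * torch.reciprocal(
    torch.clamp_min(torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val)
)
print(new)  # tensor of zeros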

def __init__(self, dim: int, eps: float = 1e-6):
"""
CoreML-friendly RMSNorm — uses `torch.linalg.vector_norm` so the op is
preserved in the CoreML graph for numerical stability.

Args:
dim (int): The dimension of the input tensor.
eps (float, optional): Floor on the L2-norm denominator
(`clamp_min(‖x‖₂, √(dim·eps))`). Prevents `0/0 = NaN` on
zero-padded positions and matches standard RMSNorm's
`rsqrt(mean(x²) + eps)` semantics on a zero input. Must be > 0.

Attributes:
eps (float): Floor coefficient consumed by `_norm`.
weight (nn.Parameter): Learnable scaling parameter.
"""
super().__init__()
assert eps > 0, "RMSNormCoreML requires eps > 0; eps=0 collapses the denominator floor and produces NaN on zero-padded positions"
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
# Floor the denominator to avoid 0 / 0 = NaN on zero-padded positions
# (chunked prefill in StaticAttentionIOManager pads each chunk to
# input_len with zeros). Use sqrt(dim * eps) so the floor matches
# standard RMSNorm's eps semantics (`rsqrt(mean(x²) + eps)`) and is
# large enough to survive fp16 (1e-6 alone underflows in fp16).
floor_val = torch.sqrt(torch.tensor(self.dim * self.eps, dtype=x.dtype))
norm_val = torch.clamp_min(
torch.linalg.vector_norm(x, dim=-1, keepdim=True), floor_val
)
rms_norm_eps0 = (
x
* torch.sqrt(torch.tensor(self.dim, dtype=x.dtype))
* torch.reciprocal(norm_val)
)
return rms_norm_eps0

def forward(self, x):
output = self._norm(x)
return output * self.weight


class RMSNormWithInputScale(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
@@ -83,3 +128,36 @@
hidden_states = self.weight * hidden_states.to(input_dtype)
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
return hidden_states.to(input_dtype)


def replace_rms_norm_for_coreml_(model: torch.nn.Module) -> torch.nn.Module:
"""In-place: walk `model` and swap every RMSNorm-family module for RMSNormCoreML.

Mirrors the post-construction transform pattern used by torchao's
`quantize_(model, config)`: instead of threading a `use_coreml_norm` flag
through every norm construction site, build the model with the standard
norms and then call this once before CoreML export. Trained scale weights
are preserved.

Swaps these classes (everything else is left alone):
* `RMSNorm` (this module)
* `ScalelessRMSNorm` (this module — no-op weight)
* `torch.nn.RMSNorm` (used for affine q_norm/k_norm in StaticAttention)
"""
for name, mod in list(model.named_modules()):
if not isinstance(mod, (RMSNorm, ScalelessRMSNorm, torch.nn.RMSNorm)):
continue
# All three carry the normalized dim either as `dim` or in `normalized_shape[-1]`.
dim = getattr(mod, "dim", None) or mod.normalized_shape[-1]
eps = getattr(mod, "eps", 1e-6) or 1e-6
new = RMSNormCoreML(dim, eps=eps)
if getattr(mod, "weight", None) is not None:
new.weight = mod.weight # preserve trained scale (no-op for ScalelessRMSNorm)
# Locate parent module via the dotted name and rebind the attribute.
if "." in name:
parent_name, attr = name.rsplit(".", 1)
parent = model.get_submodule(parent_name)
else:
parent, attr = model, name
setattr(parent, attr, new)
return model
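
A hypothetical usage sketch of this transform (the `build_llama_model` constructor and the export step are placeholders, not part of this PR):

from executorch.examples.models.llama.norm import replace_rms_norm_for_coreml_

# Build the model with the standard norms, then swap them in place once before export.
model = build_llama_model()          # placeholder: any module tree containing RMSNorm-family layers
replace_rms_norm_for_coreml_(model)  # in-place; trained scale weights are preserved
# ... proceed with the usual torch.export / CoreML lowering steps ...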