From 3d5f0d6fe9fc42a624f4b814681b6acf33b46294 Mon Sep 17 00:00:00 2001
From: Igor Fedorov
Date: Wed, 29 Apr 2026 13:42:39 -0700
Subject: [PATCH] Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity (#19096)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
The 730M dense model checkpoint uses several features that the ExecuTorch XNNPACK export path did not implement. Without them, the exported model produces numerically incorrect output.

This diff adds support for the 8 missing features:

1. `normalize_tok_embeddings` — scaleless RMSNorm after the embedding lookup
2. `qk_norm_before_rope` — conversion from GenAI args (the attention code already supported it)
3. `scale_query_by` — custom scalar multiplier on Q after QK norm
4. `use_attn_o_gate` — sigmoid gate on the attention output, using a learned linear projection of the layer input
5. `use_attn_o_norm` — scaleless per-head RMSNorm on the attention output (applied before o_gate)
6. `use_residual_gate` — NormPreservingResidualConnection with learned per-dim gates for both the attention and FFN residual connections
7. `use_ffn_learnable_scales` — RMSNormWithInputScale replacing the standard post-FFN norm, computing `rms_norm(gamma * x)` instead of `gamma * rms_norm(x)`
8. `output_soft_cap_temp` — `tanh(logits / temp) * temp` soft capping on the output logits

Additionally, this diff fixes a QK norm checkpoint compatibility issue: some checkpoints contain learned QK norm weights even though their `params.json` has `qk_norm_affine=False` (the default changed after training). The ET model was creating `ScalelessRMSNorm` (no weight parameter) based on `params.json`, silently discarding the checkpoint's trained QK norm weights, while the rlformers reference model loaded them correctly; the mismatch showed up as ~53-67 dB SNR divergence. The fix peeks at the checkpoint state dict before model construction — if QK norm weights are present, `qk_norm_affine` is overridden to `True` so the ET model creates affine QK norms that load those weights.

All features are off by default (backward compatible). They activate when the corresponding fields are set in the checkpoint's params.json and propagated through model_args_conversion. Weight key mappings added for: `attention.og.weight`, `add_attn.gate`, `add_ffn.gate`, `post_ffn_norm.weight`.
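For context, a minimal sketch of the numerics behind features 1 and 8; their wiring into the transformer forward pass is not visible in the hunks below, and these standalone function names are illustrative rather than the ones used in the code:

```python
import torch


def soft_cap(logits: torch.Tensor, temp: float) -> torch.Tensor:
    # output_soft_cap_temp: smoothly squash the output logits into (-temp, temp).
    return torch.tanh(logits / temp) * temp


def scaleless_rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # normalize_tok_embeddings: RMSNorm with no learned weight, applied to the
    # token embeddings right after the embedding lookup.
    x_f = x.float()
    return (x_f * torch.rsqrt((x_f * x_f).mean(-1, keepdim=True) + eps)).type_as(x)
```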
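The checkpoint-peek override itself lives outside the five files touched below, so it does not appear in this diff. The following is only a rough sketch of the idea, assuming a flat checkpoint state dict; the helper name and the key pattern are hypothetical:

```python
import torch


def maybe_force_affine_qk_norm(params: dict, checkpoint_path: str) -> dict:
    # Peek at the checkpoint before constructing the model: if any layer carries
    # trained QK norm weights, flip qk_norm_affine to True so the ET model builds
    # affine RMSNorm modules that will actually load those weights.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    has_qk_norm_weights = any(
        ("q_norm" in key or "k_norm" in key) and key.endswith(".weight")
        for key in state_dict
    )
    if has_qk_norm_weights and not params.get("qk_norm_affine", True):
        params = {**params, "qk_norm_affine": True}
    return params
```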
Reviewed By: chinnadhurai, digantdesai

Differential Revision: D102030169
---
 examples/models/llama/attention.py         | 61 ++++++++++++---
 examples/models/llama/llama_transformer.py | 86 +++++++++++++++++++---
 examples/models/llama/model_args.py        | 12 +++
 examples/models/llama/norm.py              | 43 ++++++-----
 examples/models/llama/static_attention.py  | 52 +++++++++----
 5 files changed, 201 insertions(+), 53 deletions(-)

diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py
index d6dff173072..d43533b5a70 100644
--- a/examples/models/llama/attention.py
+++ b/examples/models/llama/attention.py
@@ -7,7 +7,11 @@
 import torch.nn.functional as F
 from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
+from executorch.examples.models.llama.norm import (
+    RMSNorm,
+    RMSNormGated,
+    ScalelessRMSNorm,
+)
 from executorch.examples.models.llama.rope import Rope
 
@@ -375,6 +379,9 @@ def __init__(
         self.qk_norm_before_rope = args.qk_norm_before_rope
         self.use_q_gate = args.use_q_gate
         self.enable_dynamic_shape = args.enable_dynamic_shape
+        self.scale_query_by = args.scale_query_by
+        self.use_attn_o_gate = args.use_attn_o_gate
+        self.use_attn_o_norm = args.use_attn_o_norm
 
         q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)
         # YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
@@ -417,17 +424,26 @@ def __init__(
     def _init_norms(self, args: ModelArgs) -> None:
         """Initialize QK normalization layers."""
         if self.use_qk_norm:
-            self.q_norm_fn = RMSNorm(
-                self.head_dim,
-                eps=args.norm_eps,
-                add_unit_offset=args.rms_norm_add_unit_offset,
-            )
-            if self.has_kv_weights:
-                self.k_norm_fn = RMSNorm(
+            if args.qk_norm_affine:
+                self.q_norm_fn = RMSNorm(
                     self.head_dim,
                     eps=args.norm_eps,
                     add_unit_offset=args.rms_norm_add_unit_offset,
                 )
+                if self.has_kv_weights:
+                    self.k_norm_fn = RMSNorm(
+                        self.head_dim,
+                        eps=args.norm_eps,
+                        add_unit_offset=args.rms_norm_add_unit_offset,
+                    )
+            else:
+                self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+                if self.has_kv_weights:
+                    self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_norm:
+            self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_gate:
+            self.og = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)
 
     def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
         """Initialize Q/K/V/O projection layers."""
@@ -478,6 +494,8 @@ def _prepare_qkv_shared(
 
         if self.use_qk_norm and self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by
 
         # Apply RoPE to Q only (K already has RoPE from donor layer)
         q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
@@ -485,6 +503,8 @@
 
         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by
 
         return q, k, v
 
@@ -508,6 +528,8 @@ def _prepare_qkv(
 
         if self.use_qk_norm and self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by
             k = self.k_norm_fn(k)
 
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
@@ -518,6 +540,8 @@
 
         if self.use_qk_norm and not self.qk_norm_before_rope:
            q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by
             k = self.k_norm_fn(k)
 
         return q, k, v
 
@@ -582,8 +606,7 @@ def forward(
         )
         output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
 
-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)
 
         if shared_kv is None and self.num_kv_shared_layers > 0:
             update = {"kv_to_share": (k, v)}
@@ -602,13 +625,27 @@ def forward(
         output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
 
         output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)
 
         output = self.wo(output)
 
         return output, None
 
+    def _apply_output_transforms(
+        self, output: torch.Tensor, x: torch.Tensor, gate, bsz: int, seqlen: int
+    ) -> torch.Tensor:
+        if self.use_attn_o_norm or self.use_attn_o_gate:
+            output_4d = output.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            if self.use_attn_o_norm:
+                output_4d = self.o_norm(output_4d)
+            if self.use_attn_o_gate:
+                og = self.og(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+                output_4d = torch.sigmoid(og) * output_4d
+            output = output_4d.reshape(bsz, seqlen, -1)
+        if gate is not None:
+            output = output * torch.sigmoid(gate)
+        return output
+
 
 def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
     inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index e74ae810a02..d87eef3f906 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -7,6 +7,7 @@
 # Please refer to README.md in the same folder for more information.
 
+import math
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
@@ -20,7 +21,11 @@
 )
 from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm
+from executorch.examples.models.llama.norm import (
+    RMSNorm,
+    RMSNormWithInputScale,
+    ScalelessRMSNorm,
+)
 from executorch.examples.models.llama.rope import Rope
 from torch import nn
 
@@ -51,6 +56,26 @@ def _is_kv_shared_layer(
     return layer_idx >= first_shared and first_shared > 0
 
 
+class NormPreservingResidualConnection(nn.Module):
+    def __init__(
+        self, dim: int, init_scale: float, temperature: float = 0.3, eps: float = 1e-3
+    ):
+        super().__init__()
+        self.eps = eps
+        self.temperature = temperature
+        p = max(0.0 + eps, min(1.0 - eps, init_scale))
+        init_param = math.log(p / (1.0 - p)) * temperature
+        self.gate = nn.Parameter(torch.full((dim,), init_param))
+
+    def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
+        dtype = stream.dtype
+        w = self.gate.view(*([1] * (stream.ndim - 1)), -1).float()
+        beta = torch.sigmoid(w / self.temperature)
+        alpha_sq = torch.sigmoid(-w / self.temperature) * (1.0 + beta)
+        alpha = torch.sqrt(torch.clamp(alpha_sq, min=self.eps))
+        return (alpha * stream.float() + beta * branch.float()).to(dtype)
+
+
 class ConditionalFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -99,7 +124,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class TransformerBlock(nn.Module):
     def __init__(
-        self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
+        self,
+        args: ModelArgs,
+        attention: Attention,
+        mlp_type: str = "default",
+        layer_id: int = 0,
     ):
         """
         Transformer block with support for pre-norm and post-norm.
@@ -110,6 +139,7 @@ def __init__(
                 the attention type is registered in the ATTENTION_REGISTRY.
             mlp_type (str): MLP type for this layer. "default" for standard FFN,
                 "skip" for no FFN block.
+            layer_id (int): layer index, used for residual gate initialization.
         """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
@@ -118,6 +148,7 @@ def __init__(
         self.head_dim = args.head_dim
         self.attention = attention
         self.mlp_type = mlp_type.lower()
+        self.use_residual_gate = args.use_residual_gate
 
         assert (
             args.hidden_dim is not None
@@ -150,6 +181,20 @@ def __init__(
                 add_unit_offset=args.rms_norm_add_unit_offset,
             )
 
+        if args.use_residual_gate:
+            attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
+            ffn_init = 1.0 / (2 * layer_id + 2)
+            self.add_attn = NormPreservingResidualConnection(
+                dim=args.dim, init_scale=attn_init
+            )
+            self.add_ffn = NormPreservingResidualConnection(
+                dim=args.dim, init_scale=ffn_init
+            )
+            self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps)
+
+        if args.use_ffn_learnable_scales and self.mlp_type != "skip":
+            self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps)
+
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         """
@@ -169,21 +214,38 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
             mlp_type = args.mlp_type[layer_id]
         cls = ATTENTION_REGISTRY[args.attention_type]
         attention = cls(args, layer_id, rope, **args.attention_kwargs)
-        return TransformerBlock(args, attention, mlp_type=mlp_type)
+        return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id)
 
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
 
         if not isinstance(self.attention, AttentionSkip):
-            h = x + h
+            if self.use_residual_gate:
+                if hasattr(self, "post_attn_norm"):
+                    h = self.post_attn_norm(h)
+                h = self.add_attn(stream=x, branch=h)
+            else:
+                h = x + h
 
         if self.mlp_type == "skip":
             out = h
         elif hasattr(self, "block_sparse_moe"):
-            out = h + self.block_sparse_moe(self.ffn_norm(h))
+            ffn_out = self.block_sparse_moe(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
         else:
-            out = h + self.feed_forward(self.ffn_norm(h))
+            ffn_out = self.feed_forward(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
 
         return out, attn_options_update
 
@@ -371,7 +433,9 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             and model_args.layer_types[layer_id] == "skip_attention"
         ):
             attention = AttentionSkip()
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(
+                model_args, attention, layer_id=layer_id
+            )
             layers.append(transformer_block)
         elif (
             model_args.layer_types
@@ -386,13 +450,17 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
                 attention = linear_cls(
                     model_args, layer_id, rope, **model_args.attention_kwargs
                 )
-                transformer_block = TransformerBlock(model_args, attention)
+                transformer_block = TransformerBlock(
+                    model_args, attention, layer_id=layer_id
+                )
                 layers.append(transformer_block)
             else:
                 attention = cls(
                     model_args, layer_id, rope, **model_args.attention_kwargs
                 )  # pyre-ignore[45]
-                transformer_block = TransformerBlock(model_args, attention)
+                transformer_block = TransformerBlock(
+                    model_args, attention, layer_id=layer_id
+                )
                 layers.append(transformer_block)
 
     return Transformer(model_args, layers, rope)
diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index 104e9fe2ddd..86818bb721a 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -102,6 +102,9 @@ class ModelArgs:
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
+    qk_norm_affine: bool = (
+        True  # whether QK norm has learnable weight (False = scaleless)
+    )
     residual_multiplier: Optional[float] = (
         None  # Scaling factor applied to the residual hidden states
     )
@@ -162,6 +165,15 @@ class ModelArgs:
     final_logit_softcapping: Optional[float] = None
     attn_logit_softcapping: Optional[float] = None
 
+    # rlformers forward-pass features for on-device model parity
+    normalize_tok_embeddings: bool = False
+    scale_query_by: float = 1.0
+    use_attn_o_gate: bool = False
+    use_attn_o_norm: bool = False
+    use_residual_gate: bool = False
+    use_ffn_learnable_scales: bool = False
+    output_soft_cap_temp: Optional[float] = None
+
     def __post_init__(self):  # noqa: C901
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py
index 0189c88b13b..8ad51d4594a 100644
--- a/examples/models/llama/norm.py
+++ b/examples/models/llama/norm.py
@@ -32,33 +32,38 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
         self.weight = nn.Parameter(torch.ones(dim))
 
     def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
 
-        Args:
-            x (torch.Tensor): The input tensor.
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        if self.add_unit_offset:
+            return output * (1.0 + self.weight.float()).type_as(x)
+        return output * self.weight.type_as(x)
 
-        Returns:
-            torch.Tensor: The normalized tensor.
-        """
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
 
+class ScalelessRMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
 
     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
+        x_float = x.float()
+        return (
+            x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
+        ).type_as(x)
 
-        Args:
-            x (torch.Tensor): The input tensor.
 
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
+class RMSNormWithInputScale(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.dim = dim
+        self.weight = torch.nn.Parameter(torch.ones(dim))
 
-        """
-        output = self._norm(x.float()).type_as(x)
-        if self.add_unit_offset:
-            return output * (1.0 + self.weight.float()).type_as(x)
-        return output * self.weight.type_as(x)
+    def forward(self, x):
+        scaled = self.weight * x
+        return F.rms_norm(scaled, (self.dim,), None, self.eps)
 
 
 class RMSNormGated(nn.Module):
diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index a8f342ffa63..72ce31438d6 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -15,6 +15,7 @@
 )
 from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.norm import ScalelessRMSNorm
 from executorch.examples.models.llama.rope import Rope
 
@@ -797,6 +798,7 @@ def __init__(
         self.attention_qkv_bias = config.attention_qkv_bias
         self.use_qk_norm = config.use_qk_norm
         self.qk_norm_before_rope = config.qk_norm_before_rope
+        self.scale_query_by = getattr(config, "scale_query_by", 1.0)
         self.split_mha = split_mha
         self.is_kv_shared_layer = is_kv_shared_layer
         self.num_kv_shared_layers = config.num_kv_shared_layers
@@ -896,11 +898,18 @@ def _init_wo(self, config: ModelArgs) -> None:
 
     def _init_qk_norms(self, config: ModelArgs, is_kv_shared_layer: bool) -> None:
         if self.use_qk_norm:
-            self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
-            if is_kv_shared_layer:
-                self.k_norm = nn.Identity()
+            if getattr(config, "qk_norm_affine", True):
+                self.q_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+                if is_kv_shared_layer:
+                    self.k_norm = nn.Identity()
+                else:
+                    self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
             else:
-                self.k_norm = torch.nn.RMSNorm(self.head_dim, config.norm_eps)
+                self.q_norm = ScalelessRMSNorm(self.head_dim, eps=config.norm_eps)
+                if is_kv_shared_layer:
+                    self.k_norm = nn.Identity()
+                else:
+                    self.k_norm = ScalelessRMSNorm(self.head_dim, eps=config.norm_eps)
         else:
             self.q_norm = torch.nn.Identity()
             self.k_norm = torch.nn.Identity()
@@ -928,6 +937,18 @@ def from_attention_mha(
                 "contains LoRALinear modules. Use split_mha=False instead."
             )
 
+        if getattr(other, "use_attn_o_gate", False) or getattr(
+            other, "use_attn_o_norm", False
+        ):
+            raise ValueError(
+                "StaticAttention does not support use_attn_o_gate or use_attn_o_norm. "
+                "These features require AttentionMHA."
+            )
+
+        qk_norm_affine = (
+            hasattr(other.q_norm_fn, "weight") if other.use_qk_norm else True
+        )
+
         config = ModelArgs(
             dim=other.dim,
             n_layers=1,  # Not used in attention layer
@@ -939,8 +960,10 @@ def from_attention_mha(
            attention_qkv_bias=other.attention_qkv_bias,
            use_qk_norm=other.use_qk_norm,
            qk_norm_before_rope=other.qk_norm_before_rope,
+            qk_norm_affine=qk_norm_affine,
            norm_eps=other.q_norm_fn.eps if other.use_qk_norm else 1e-5,
            num_kv_shared_layers=getattr(other, "num_kv_shared_layers", 0),
+            scale_query_by=getattr(other, "scale_query_by", 1.0),
         )
 
         instance = cls(
@@ -1078,7 +1101,7 @@ def _apply_qk_norm(self, qs, ks=None, before_rope=False):
             before_rope (bool, optional): Whether to apply normalization before RoPE.
                 Defaults to False.
""" if self.use_qk_norm and before_rope == self.qk_norm_before_rope: - qs = [self.q_norm(q) for q in qs] + qs = [self.q_norm(q) * self.scale_query_by for q in qs] if ks is not None: ks = [self.k_norm(k) for k in ks] return qs, ks @@ -1306,12 +1329,12 @@ def _forward_mha( q = q.view(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2) if self.use_qk_norm and self.qk_norm_before_rope: - q = self.q_norm(q) + q = self.q_norm(q) * self.scale_query_by q = self.rope(q, freqs_cos, freqs_sin) if self.use_qk_norm and not self.qk_norm_before_rope: - q = self.q_norm(q) + q = self.q_norm(q) * self.scale_query_by k, v = shared_kv else: @@ -1320,14 +1343,14 @@ def _forward_mha( v = v.view(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2) if self.use_qk_norm and self.qk_norm_before_rope: - q = self.q_norm(q) + q = self.q_norm(q) * self.scale_query_by k = self.k_norm(k) q = self.rope(q, freqs_cos, freqs_sin) k = self.rope(k, freqs_cos, freqs_sin) if self.use_qk_norm and not self.qk_norm_before_rope: - q = self.q_norm(q) + q = self.q_norm(q) * self.scale_query_by k = self.k_norm(k) k, out_cache_state = self.k_caches[0].update( @@ -1430,14 +1453,17 @@ def load_weights_from_attention_mha( if other.use_qk_norm: self.use_qk_norm = True self.qk_norm_before_rope = other.qk_norm_before_rope - self.q_norm = rms_norm_class(other.q_norm_fn.dim, other.q_norm_fn.eps).to( - other.q_norm_fn.weight.dtype - ) - self.q_norm.load_state_dict(other.q_norm_fn.state_dict()) + self.scale_query_by = getattr(other, "scale_query_by", 1.0) + if hasattr(other.q_norm_fn, "weight"): + self.q_norm = rms_norm_class( + other.q_norm_fn.dim, other.q_norm_fn.eps + ).to(other.q_norm_fn.weight.dtype) + self.q_norm.load_state_dict(other.q_norm_fn.state_dict()) if ( not self.is_kv_shared_layer and hasattr(other, "k_norm_fn") and other.k_norm_fn is not None + and hasattr(other.k_norm_fn, "weight") ): self.k_norm = rms_norm_class( other.k_norm_fn.dim, other.k_norm_fn.eps