
Commit 6d2a84a

ifed-ucsd authored and facebook-github-bot committed
Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity (#19096)
Summary:

The 730M dense model checkpoint uses several rlformers features that the ExecuTorch XNNPACK export path did not implement. Without these, the exported model produces numerically incorrect output. This diff adds support for 8 missing features:

1. `normalize_tok_embeddings` — scaleless RMSNorm after embedding lookup
2. `qk_norm_before_rope` — conversion from GenAI args (attention code already supported it)
3. `scale_query_by` — custom scalar multiplier on Q after QK norm
4. `use_attn_o_gate` — sigmoid gate on attention output using a learned linear projection of the layer input
5. `use_attn_o_norm` — scaleless per-head RMSNorm on attention output (applied before o_gate)
6. `use_residual_gate` — NormPreservingResidualConnection with learned per-dim gates for both the attention and FFN residual connections
7. `use_ffn_learnable_scales` — RMSNormWithInputScale replacing the standard post-FFN norm, computing `rms_norm(gamma * x)` instead of `gamma * rms_norm(x)`
8. `output_soft_cap_temp` — `tanh(logits / temp) * temp` soft capping on the output logits

All features are off by default (backward compatible). They activate when the corresponding fields are set in the checkpoint's params.json and propagated through model_args_conversion.

Weight key mappings added for: `attention.og.weight`, `add_attn.gate`, `add_ffn.gate`, `post_ffn_norm.weight`.

Differential Revision: D102030169
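Of the features listed above, items 1 and 8 are pure pointwise math and are not visible in the hunks below; here is a minimal sketch of what the summary describes (function names, shapes, and the eps value are illustrative assumptions, not the ExecuTorch implementation):

    import torch

    def scaleless_rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # normalize_tok_embeddings: scaleless RMSNorm applied to embeddings after lookup
        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)

    def soft_cap(logits: torch.Tensor, temp: float) -> torch.Tensor:
        # output_soft_cap_temp: tanh(logits / temp) * temp bounds logits to (-temp, temp)
        return torch.tanh(logits / temp) * temp

    normalized_emb = scaleless_rms_norm(torch.randn(1, 4, 64))      # illustrative shapes
    capped_logits = soft_cap(torch.randn(1, 4, 32000), temp=30.0)   # illustrative vocab size / temp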
1 parent f9f29e7 · commit 6d2a84a

4 files changed: 129 additions & 44 deletions

File tree

examples/models/llama/attention.py
examples/models/llama/llama_transformer.py
examples/models/llama/model_args.py
examples/models/llama/norm.py

examples/models/llama/attention.py

Lines changed: 41 additions & 16 deletions
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
+from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated, ScalelessRMSNorm
 from executorch.examples.models.llama.rope import Rope


@@ -375,6 +375,9 @@ def __init__(
         self.qk_norm_before_rope = args.qk_norm_before_rope
         self.use_q_gate = args.use_q_gate
         self.enable_dynamic_shape = args.enable_dynamic_shape
+        self.scale_query_by = args.scale_query_by
+        self.use_attn_o_gate = args.use_attn_o_gate
+        self.use_attn_o_norm = args.use_attn_o_norm
         q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)

         # YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
@@ -417,17 +420,26 @@ def __init__(
417420
def _init_norms(self, args: ModelArgs) -> None:
418421
"""Initialize QK normalization layers."""
419422
if self.use_qk_norm:
420-
self.q_norm_fn = RMSNorm(
421-
self.head_dim,
422-
eps=args.norm_eps,
423-
add_unit_offset=args.rms_norm_add_unit_offset,
424-
)
425-
if self.has_kv_weights:
426-
self.k_norm_fn = RMSNorm(
423+
if args.qk_norm_affine:
424+
self.q_norm_fn = RMSNorm(
427425
self.head_dim,
428426
eps=args.norm_eps,
429427
add_unit_offset=args.rms_norm_add_unit_offset,
430428
)
429+
if self.has_kv_weights:
430+
self.k_norm_fn = RMSNorm(
431+
self.head_dim,
432+
eps=args.norm_eps,
433+
add_unit_offset=args.rms_norm_add_unit_offset,
434+
)
435+
else:
436+
self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
437+
if self.has_kv_weights:
438+
self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
439+
if self.use_attn_o_norm:
440+
self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
441+
if self.use_attn_o_gate:
442+
self.og = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)
431443

432444
def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
433445
"""Initialize Q/K/V/O projection layers."""
@@ -477,14 +489,14 @@ def _prepare_qkv_shared(
         k, v = shared_kv

         if self.use_qk_norm and self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by

         # Apply RoPE to Q only (K already has RoPE from donor layer)
         q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
         q = q.transpose(1, 2)

         if self.use_qk_norm and not self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by

         return q, k, v

@@ -507,7 +519,7 @@ def _prepare_qkv(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

         if self.use_qk_norm and self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by
             k = self.k_norm_fn(k)

         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
@@ -517,7 +529,7 @@ def _prepare_qkv(
         v = v.transpose(1, 2)

         if self.use_qk_norm and not self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by
             k = self.k_norm_fn(k)

         return q, k, v
@@ -582,8 +594,7 @@ def forward(
         )

         output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

         if shared_kv is None and self.num_kv_shared_layers > 0:
             update = {"kv_to_share": (k, v)}
@@ -602,13 +613,27 @@
         output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

         output = self.wo(output)

         return output, None

+    def _apply_output_transforms(
+        self, output: torch.Tensor, x: torch.Tensor, gate, bsz: int, seqlen: int
+    ) -> torch.Tensor:
+        if self.use_attn_o_norm or self.use_attn_o_gate:
+            output_4d = output.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            if self.use_attn_o_norm:
+                output_4d = self.o_norm(output_4d)
+            if self.use_attn_o_gate:
+                og = self.og(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+                output_4d = torch.sigmoid(og) * output_4d
+            output = output_4d.reshape(bsz, seqlen, -1)
+        if gate is not None:
+            output = output * torch.sigmoid(gate)
+        return output
+

 def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
     inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
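For readability, a functional restatement of the new `_apply_output_transforms` path above, as a standalone sketch with illustrative shapes and a hypothetical `og_weight` (the real module keeps these as `self.o_norm`, `self.og`, and flags on the Attention class): the per-head `o_norm` runs first, then the sigmoid gate computed from the layer input `x`, then the pre-existing optional gate.

    import torch

    def scaleless_rms_norm(x, eps=1e-6):
        # Same math as ScalelessRMSNorm in norm.py: per-last-dim RMS, no learned weight.
        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)

    def attn_output_transforms(output, x, og_weight, gate=None,
                               n_heads=4, head_dim=8,
                               use_attn_o_norm=True, use_attn_o_gate=True):
        bsz, seqlen, _ = output.shape
        out4d = output.view(bsz, seqlen, n_heads, head_dim)
        if use_attn_o_norm:
            out4d = scaleless_rms_norm(out4d)           # per-head norm, applied first
        if use_attn_o_gate:
            og = (x @ og_weight.t()).view(bsz, seqlen, n_heads, head_dim)
            out4d = torch.sigmoid(og) * out4d           # gate computed from the layer input x
        output = out4d.reshape(bsz, seqlen, -1)
        if gate is not None:
            output = output * torch.sigmoid(gate)       # pre-existing optional gate, unchanged
        return output

    # Illustrative shapes: dim=16, n_heads=4, head_dim=8 -> n_heads * head_dim = 32.
    x = torch.randn(1, 3, 16)
    out = attn_output_transforms(torch.randn(1, 3, 32), x, og_weight=torch.randn(32, 16))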

examples/models/llama/llama_transformer.py

Lines changed: 57 additions & 9 deletions
@@ -7,6 +7,7 @@

 # Please refer to README.md in the same folder for more information.

+import math
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
@@ -20,7 +21,7 @@
 )
 from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm
+from executorch.examples.models.llama.norm import RMSNorm, RMSNormWithInputScale, ScalelessRMSNorm
 from executorch.examples.models.llama.rope import Rope
 from torch import nn

@@ -51,6 +52,23 @@ def _is_kv_shared_layer(
     return layer_idx >= first_shared and first_shared > 0


+class NormPreservingResidualConnection(nn.Module):
+    def __init__(self, dim: int, init_scale: float, temperature: float = 0.3, eps: float = 1e-3):
+        super().__init__()
+        self.eps = eps
+        self.temperature = temperature
+        p = max(0.0 + eps, min(1.0 - eps, init_scale))
+        init_param = math.log(p / (1.0 - p)) * temperature
+        self.gate = nn.Parameter(torch.full((dim,), init_param))
+
+    def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
+        w = self.gate.view(*([1] * (stream.ndim - 1)), -1)
+        beta = torch.sigmoid(w / self.temperature)
+        alpha_sq = torch.sigmoid(-w / self.temperature) * (1.0 + beta)
+        alpha = torch.sqrt(torch.clamp(alpha_sq, min=self.eps))
+        return alpha * stream + beta * branch
+
+
 class ConditionalFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -99,7 +117,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

 class TransformerBlock(nn.Module):
     def __init__(
-        self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
+        self, args: ModelArgs, attention: Attention, mlp_type: str = "default",
+        layer_id: int = 0,
     ):
         """
         Transformer block with support for pre-norm and post-norm.
@@ -110,6 +129,7 @@ def __init__(
                 the attention type is registered in the ATTENTION_REGISTRY.
             mlp_type (str): MLP type for this layer. "default" for standard
                 FFN, "skip" for no FFN block.
+            layer_id (int): layer index, used for residual gate initialization.
         """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
@@ -118,6 +138,7 @@ def __init__(
         self.head_dim = args.head_dim
         self.attention = attention
         self.mlp_type = mlp_type.lower()
+        self.use_residual_gate = args.use_residual_gate

         assert (
             args.hidden_dim is not None
@@ -150,6 +171,16 @@ def __init__(
             add_unit_offset=args.rms_norm_add_unit_offset,
         )

+        if args.use_residual_gate:
+            attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
+            ffn_init = 1.0 / (2 * layer_id + 2)
+            self.add_attn = NormPreservingResidualConnection(dim=args.dim, init_scale=attn_init)
+            self.add_ffn = NormPreservingResidualConnection(dim=args.dim, init_scale=ffn_init)
+            self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps)
+
+        if args.use_ffn_learnable_scales and self.mlp_type != "skip":
+            self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps)
+
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         """
@@ -169,21 +200,38 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         mlp_type = args.mlp_type[layer_id]
         cls = ATTENTION_REGISTRY[args.attention_type]
         attention = cls(args, layer_id, rope, **args.attention_kwargs)
-        return TransformerBlock(args, attention, mlp_type=mlp_type)
+        return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id)

     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
         if not isinstance(self.attention, AttentionSkip):
-            h = x + h
+            if self.use_residual_gate:
+                if hasattr(self, "post_attn_norm"):
+                    h = self.post_attn_norm(h)
+                h = self.add_attn(stream=x, branch=h)
+            else:
+                h = x + h

         if self.mlp_type == "skip":
             out = h
         elif hasattr(self, "block_sparse_moe"):
-            out = h + self.block_sparse_moe(self.ffn_norm(h))
+            ffn_out = self.block_sparse_moe(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
         else:
-            out = h + self.feed_forward(self.ffn_norm(h))
+            ffn_out = self.feed_forward(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
         return out, attn_options_update

@@ -371,7 +419,7 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             and model_args.layer_types[layer_id] == "skip_attention"
         ):
             attention = AttentionSkip()
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
             layers.append(transformer_block)
         elif (
             model_args.layer_types
@@ -386,13 +434,13 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             attention = linear_cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
            )
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
             layers.append(transformer_block)
         else:
             attention = cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
             )  # pyre-ignore[45]
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
             layers.append(transformer_block)

     return Transformer(model_args, layers, rope)
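A small numeric sketch of the NormPreservingResidualConnection initialization added above (dim and init_scale are illustrative): with `init_scale = p`, the gate is set so that `beta == p` and `alpha**2 + beta**2 == 1` at initialization, which keeps the expected norm of `alpha * stream + beta * branch` roughly constant for independent, equal-variance inputs.

    import math
    import torch

    dim, init_scale, temperature, eps = 4, 0.25, 0.3, 1e-3
    p = max(eps, min(1.0 - eps, init_scale))
    gate = torch.full((dim,), math.log(p / (1.0 - p)) * temperature)

    beta = torch.sigmoid(gate / temperature)                     # ~= 0.25 per dim
    alpha = torch.sqrt(torch.clamp(torch.sigmoid(-gate / temperature) * (1.0 + beta), min=eps))
    print(alpha ** 2 + beta ** 2)                                # ~= 1.0 per dim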

examples/models/llama/model_args.py

Lines changed: 10 additions & 0 deletions
@@ -102,6 +102,7 @@ class ModelArgs:
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
+    qk_norm_affine: bool = True  # whether QK norm has learnable weight (False = scaleless)
     residual_multiplier: Optional[float] = (
         None  # Scaling factor applied to the residual hidden states
     )
@@ -162,6 +163,15 @@ class ModelArgs:
     final_logit_softcapping: Optional[float] = None
     attn_logit_softcapping: Optional[float] = None

+    # rlformers forward-pass features for on-device model parity
+    normalize_tok_embeddings: bool = False
+    scale_query_by: float = 1.0
+    use_attn_o_gate: bool = False
+    use_attn_o_norm: bool = False
+    use_residual_gate: bool = False
+    use_ffn_learnable_scales: bool = False
+    output_soft_cap_temp: Optional[float] = None
+
     def __post_init__(self):  # noqa: C901
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
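A hedged configuration sketch: the new fields all default off, so an exported model only picks them up when they are set before construct_transformer runs. Values below are illustrative only; real values come from the checkpoint's params.json via model_args_conversion, which is not shown in this diff.

    from executorch.examples.models.llama.model_args import ModelArgs

    # Illustrative values only, not the 730M checkpoint's actual configuration.
    args = ModelArgs(
        normalize_tok_embeddings=True,
        qk_norm_affine=False,          # scaleless QK norm
        scale_query_by=2.0,
        use_attn_o_gate=True,
        use_attn_o_norm=True,
        use_residual_gate=True,
        use_ffn_learnable_scales=True,
        output_soft_cap_temp=30.0,
    )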

examples/models/llama/norm.py

Lines changed: 21 additions & 19 deletions
@@ -32,33 +32,35 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
         self.weight = nn.Parameter(torch.ones(dim))

     def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

-        Args:
-            x (torch.Tensor): The input tensor.
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        if self.add_unit_offset:
+            return output * (1.0 + self.weight.float()).type_as(x)
+        return output * self.weight.type_as(x)

-        Returns:
-            torch.Tensor: The normalized tensor.

-        """
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+class ScalelessRMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.dim = dim
+        self.eps = eps

     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

-        Args:
-            x (torch.Tensor): The input tensor.

-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
+class RMSNormWithInputScale(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.dim = dim
+        self.weight = torch.nn.Parameter(torch.ones(dim))

-        """
-        output = self._norm(x.float()).type_as(x)
-        if self.add_unit_offset:
-            return output * (1.0 + self.weight.float()).type_as(x)
-        return output * self.weight.type_as(x)
+    def forward(self, x):
+        scaled = self.weight * x
+        return scaled * torch.rsqrt((scaled * scaled).mean(-1, keepdim=True) + self.eps)


 class RMSNormGated(nn.Module):
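A quick sketch of why RMSNormWithInputScale differs from the standard `gamma * rms_norm(x)` post-norm: the weight is applied before the RMS statistic is computed, so it changes the normalizer itself rather than just rescaling the output. This is a functional re-implementation for illustration, not the modules above; the non-unit weight is an illustrative value.

    import torch

    def rms_norm(x, eps=1e-5):
        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)

    x = torch.randn(2, 8)
    gamma = torch.full((8,), 2.0)              # illustrative non-unit weight

    with_input_scale = rms_norm(gamma * x)     # what RMSNormWithInputScale computes
    standard_post_norm = gamma * rms_norm(x)   # what the replaced RMSNorm would compute

    print(torch.allclose(with_input_scale, standard_post_norm))  # False in general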
