Skip to content

Commit 18e2b65

Browse files
ifed-ucsd authored and facebook-github-bot committed
Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity
Summary: The 730M dense model checkpoint (16node_IFT_avoDistilled730m) uses several rlformers features that the ExecuTorch XNNPACK export path did not implement. Without these, the exported model produces numerically incorrect output. This diff adds support for 8 missing features: 1. `normalize_tok_embeddings` — scaleless RMSNorm after embedding lookup 2. `qk_norm_before_rope` — conversion from GenAI args (attention code already supported it) 3. `scale_query_by` — custom scalar multiplier on Q after QK norm 4. `use_attn_o_gate` — sigmoid gate on attention output using a learned linear projection of the layer input 5. `use_attn_o_norm` — scaleless per-head RMSNorm on attention output (applied before o_gate) 6. `use_residual_gate` — NormPreservingResidualConnection with learned per-dim gates for both attention and FFN residual connections 7. `use_ffn_learnable_scales` — RMSNormWithInputScale replacing standard post-FFN norm, computing `rms_norm(gamma * x)` instead of `gamma * rms_norm(x)` 8. `output_soft_cap_temp` — `tanh(logits/temp) * temp` soft capping on output logits All features are off by default (backward compatible). They activate when the corresponding fields are set in the checkpoint's params.json and propagated through model_args_conversion. Weight key mappings added for: `attention.og.weight`, `add_attn.gate`, `add_ffn.gate`, `post_ffn_norm.weight`. Reviewed By: chinnadhurai, digantdesai Differential Revision: D102030169
1 parent 9207001 commit 18e2b65

5 files changed

Lines changed: 172 additions & 53 deletions

File tree

examples/models/llama/attention.py

Lines changed: 45 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -7,7 +7,7 @@
77
import torch.nn.functional as F
88
from executorch.examples.models.llama.lora import LoRALinear
99
from executorch.examples.models.llama.model_args import ModelArgs
10-
from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
10+
from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated, ScalelessRMSNorm
1111
from executorch.examples.models.llama.rope import Rope
1212

1313

@@ -375,6 +375,9 @@ def __init__(
375375
self.qk_norm_before_rope = args.qk_norm_before_rope
376376
self.use_q_gate = args.use_q_gate
377377
self.enable_dynamic_shape = args.enable_dynamic_shape
378+
self.scale_query_by = args.scale_query_by
379+
self.use_attn_o_gate = args.use_attn_o_gate
380+
self.use_attn_o_norm = args.use_attn_o_norm
378381
q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)
379382

380383
# YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
@@ -417,17 +420,26 @@ def __init__(
417420
def _init_norms(self, args: ModelArgs) -> None:
    """Initialize QK normalization and optional attention-output norm/gate.

    When ``use_qk_norm`` is set, Q (and K, if this layer owns KV weights)
    get either an affine RMSNorm or a scaleless RMSNorm depending on
    ``args.qk_norm_affine``. ``use_attn_o_norm`` adds a per-head scaleless
    RMSNorm on the attention output, and ``use_attn_o_gate`` adds a learned
    linear projection of the layer input used as a sigmoid gate on that
    output (see ``_apply_output_transforms``).

    Args:
        args: model hyperparameters; reads norm_eps, qk_norm_affine,
            rms_norm_add_unit_offset, and dim.
    """
    if self.use_qk_norm:
        # Build both Q and K norms through one factory so they always
        # agree in type and hyperparameters (was duplicated inline).
        if args.qk_norm_affine:
            def make_norm():
                return RMSNorm(
                    self.head_dim,
                    eps=args.norm_eps,
                    add_unit_offset=args.rms_norm_add_unit_offset,
                )
        else:
            def make_norm():
                return ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
        self.q_norm_fn = make_norm()
        if self.has_kv_weights:
            self.k_norm_fn = make_norm()
    if self.use_attn_o_norm:
        # Scaleless per-head RMSNorm on attention output.
        self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
    if self.use_attn_o_gate:
        # NOTE(review): sized n_heads * head_dim here, but
        # _apply_output_transforms views the projection as n_local_heads —
        # identical unless heads are sharded; confirm for TP configs.
        self.og = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)
431443

432444
def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
433445
"""Initialize Q/K/V/O projection layers."""
@@ -478,13 +490,17 @@ def _prepare_qkv_shared(
478490

479491
if self.use_qk_norm and self.qk_norm_before_rope:
480492
q = self.q_norm_fn(q)
493+
if self.scale_query_by != 1.0:
494+
q = q * self.scale_query_by
481495

482496
# Apply RoPE to Q only (K already has RoPE from donor layer)
483497
q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
484498
q = q.transpose(1, 2)
485499

486500
if self.use_qk_norm and not self.qk_norm_before_rope:
487501
q = self.q_norm_fn(q)
502+
if self.scale_query_by != 1.0:
503+
q = q * self.scale_query_by
488504

489505
return q, k, v
490506

@@ -508,6 +524,8 @@ def _prepare_qkv(
508524

509525
if self.use_qk_norm and self.qk_norm_before_rope:
510526
q = self.q_norm_fn(q)
527+
if self.scale_query_by != 1.0:
528+
q = q * self.scale_query_by
511529
k = self.k_norm_fn(k)
512530

513531
q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
@@ -518,6 +536,8 @@ def _prepare_qkv(
518536

519537
if self.use_qk_norm and not self.qk_norm_before_rope:
520538
q = self.q_norm_fn(q)
539+
if self.scale_query_by != 1.0:
540+
q = q * self.scale_query_by
521541
k = self.k_norm_fn(k)
522542

523543
return q, k, v
@@ -582,8 +602,7 @@ def forward(
582602
)
583603

584604
output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
585-
if gate is not None:
586-
output = output * torch.sigmoid(gate)
605+
output = self._apply_output_transforms(output, x, gate, bsz, seqlen)
587606

588607
if shared_kv is None and self.num_kv_shared_layers > 0:
589608
update = {"kv_to_share": (k, v)}
@@ -602,13 +621,27 @@ def forward(
602621
output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
603622

604623
output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
605-
if gate is not None:
606-
output = output * torch.sigmoid(gate)
624+
output = self._apply_output_transforms(output, x, gate, bsz, seqlen)
607625

608626
output = self.wo(output)
609627

610628
return output, None
611629

630+
def _apply_output_transforms(
    self, output: torch.Tensor, x: torch.Tensor, gate, bsz: int, seqlen: int
) -> torch.Tensor:
    """Post-SDPA transforms: optional per-head norm, output gate, Q-gate.

    Args:
        output: attention output, shape (bsz, seqlen, n_local_heads * head_dim).
        x: the layer input (pre-QKV projection); source of the learned
            output-gate projection.
        gate: optional Q-gate activations; sigmoid-gated onto the output
            when not None.
        bsz: batch size.
        seqlen: sequence length.

    Returns:
        Transformed attention output, same shape as ``output``.
    """
    if self.use_attn_o_norm or self.use_attn_o_gate:
        # Split into heads so the norm and gate act per head.
        heads = output.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        if self.use_attn_o_norm:
            # Norm is applied before the output gate by design.
            heads = self.o_norm(heads)
        if self.use_attn_o_gate:
            # NOTE(review): og projects to n_heads * head_dim but is viewed
            # as n_local_heads — confirm these match in all configurations.
            gate_proj = self.og(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
            heads = torch.sigmoid(gate_proj) * heads
        output = heads.reshape(bsz, seqlen, -1)
    if gate is not None:
        output = output * torch.sigmoid(gate)
    return output
644+
612645

613646
def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
614647
inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)

examples/models/llama/llama_transformer.py

Lines changed: 58 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -7,6 +7,7 @@
77

88
# Please refer to README.md in the same folder for more information.
99

10+
import math
1011
from typing import Any, Dict, Optional, Tuple, Union
1112

1213
import torch
@@ -20,7 +21,7 @@
2021
)
2122
from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
2223
from executorch.examples.models.llama.model_args import ModelArgs
23-
from executorch.examples.models.llama.norm import RMSNorm
24+
from executorch.examples.models.llama.norm import RMSNorm, RMSNormWithInputScale, ScalelessRMSNorm
2425
from executorch.examples.models.llama.rope import Rope
2526
from torch import nn
2627

@@ -51,6 +52,24 @@ def _is_kv_shared_layer(
5152
return layer_idx >= first_shared and first_shared > 0
5253

5354

55+
class NormPreservingResidualConnection(nn.Module):
    """Residual combine ``alpha * stream + beta * branch`` with learned gates.

    A per-dimension gate parameterizes the branch weight
    ``beta = sigmoid(gate / temperature)``; the stream weight satisfies
    ``alpha^2 = sigmoid(-gate/T) * (1 + beta) = (1 - beta)(1 + beta)
    = 1 - beta^2``, so ``alpha^2 + beta^2 == 1`` and the expected norm of
    the combined stream is preserved.

    Args:
        dim: hidden dimension (one gate per channel).
        init_scale: desired initial branch weight beta, clamped to
            (eps, 1 - eps) before inverting the sigmoid.
        temperature: sigmoid temperature for the gate.
        eps: clamp floor for init_scale and for alpha^2.
    """

    def __init__(self, dim: int, init_scale: float, temperature: float = 0.3, eps: float = 1e-3):
        super().__init__()
        self.eps = eps
        self.temperature = temperature
        # Keep the initial beta strictly inside (0, 1) so the logit is finite,
        # then invert the tempered sigmoid applied in forward().
        clamped = min(1.0 - eps, max(0.0 + eps, init_scale))
        init_logit = math.log(clamped / (1.0 - clamped)) * temperature
        self.gate = nn.Parameter(torch.full((dim,), init_logit))

    def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
        orig_dtype = stream.dtype
        # Broadcast the per-dim gate over all leading dims; compute in fp32.
        w = self.gate.view(*([1] * (stream.ndim - 1)), -1).float()
        beta = torch.sigmoid(w / self.temperature)
        # sigmoid(-z) == 1 - sigmoid(z), hence alpha_sq == 1 - beta**2.
        alpha_sq = torch.sigmoid(-w / self.temperature) * (1.0 + beta)
        alpha = torch.clamp(alpha_sq, min=self.eps).sqrt()
        return (alpha * stream.float() + beta * branch.float()).to(orig_dtype)
71+
72+
5473
class ConditionalFeedForward(nn.Module):
5574
def __init__(self, args: ModelArgs):
5675
super().__init__()
@@ -99,7 +118,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
99118

100119
class TransformerBlock(nn.Module):
101120
def __init__(
102-
self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
121+
self, args: ModelArgs, attention: Attention, mlp_type: str = "default",
122+
layer_id: int = 0,
103123
):
104124
"""
105125
Transformer block with support for pre-norm and post-norm.
@@ -110,6 +130,7 @@ def __init__(
110130
the attention type is registered in the ATTENTION_REGISTRY.
111131
mlp_type (str): MLP type for this layer. "default" for standard
112132
FFN, "skip" for no FFN block.
133+
layer_id (int): layer index, used for residual gate initialization.
113134
"""
114135
super().__init__()
115136
self.use_kv_cache = args.use_kv_cache
@@ -118,6 +139,7 @@ def __init__(
118139
self.head_dim = args.head_dim
119140
self.attention = attention
120141
self.mlp_type = mlp_type.lower()
142+
self.use_residual_gate = args.use_residual_gate
121143

122144
assert (
123145
args.hidden_dim is not None
@@ -150,6 +172,16 @@ def __init__(
150172
add_unit_offset=args.rms_norm_add_unit_offset,
151173
)
152174

175+
if args.use_residual_gate:
176+
attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
177+
ffn_init = 1.0 / (2 * layer_id + 2)
178+
self.add_attn = NormPreservingResidualConnection(dim=args.dim, init_scale=attn_init)
179+
self.add_ffn = NormPreservingResidualConnection(dim=args.dim, init_scale=ffn_init)
180+
self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps)
181+
182+
if args.use_ffn_learnable_scales and self.mlp_type != "skip":
183+
self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps)
184+
153185
@classmethod
154186
def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
155187
"""
@@ -169,21 +201,38 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
169201
mlp_type = args.mlp_type[layer_id]
170202
cls = ATTENTION_REGISTRY[args.attention_type]
171203
attention = cls(args, layer_id, rope, **args.attention_kwargs)
172-
return TransformerBlock(args, attention, mlp_type=mlp_type)
204+
return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id)
173205

174206
def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN
    """One transformer block: attention + residual, then FFN/MoE + residual.

    With ``use_residual_gate`` both residual adds are replaced by
    NormPreservingResidualConnection (and, when present, a scaleless
    post-attention norm is applied to the attention branch first). When
    ``post_ffn_norm`` exists (use_ffn_learnable_scales), it is applied to
    the FFN/MoE output before the residual combine.

    Returns:
        Tuple of (block output, attention options update).
    """
    h, attn_options_update = self.attention(
        self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
    )
    if not isinstance(self.attention, AttentionSkip):
        if self.use_residual_gate:
            if hasattr(self, "post_attn_norm"):
                h = self.post_attn_norm(h)
            h = self.add_attn(stream=x, branch=h)
        else:
            h = x + h

    if self.mlp_type == "skip":
        out = h
    else:
        # MoE layers use block_sparse_moe; dense layers use feed_forward.
        # The post-norm / residual-gate logic below was previously duplicated
        # verbatim in both branches; select the FFN module once instead.
        ffn = (
            self.block_sparse_moe
            if hasattr(self, "block_sparse_moe")
            else self.feed_forward
        )
        ffn_out = ffn(self.ffn_norm(h))
        if hasattr(self, "post_ffn_norm"):
            ffn_out = self.post_ffn_norm(ffn_out)
        out = (
            self.add_ffn(stream=h, branch=ffn_out)
            if self.use_residual_gate
            else h + ffn_out
        )
    return out, attn_options_update
188237

189238

@@ -371,7 +420,7 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
371420
and model_args.layer_types[layer_id] == "skip_attention"
372421
):
373422
attention = AttentionSkip()
374-
transformer_block = TransformerBlock(model_args, attention)
423+
transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
375424
layers.append(transformer_block)
376425
elif (
377426
model_args.layer_types
@@ -386,13 +435,13 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
386435
attention = linear_cls(
387436
model_args, layer_id, rope, **model_args.attention_kwargs
388437
)
389-
transformer_block = TransformerBlock(model_args, attention)
438+
transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
390439
layers.append(transformer_block)
391440
else:
392441
attention = cls(
393442
model_args, layer_id, rope, **model_args.attention_kwargs
394443
) # pyre-ignore[45]
395-
transformer_block = TransformerBlock(model_args, attention)
444+
transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
396445
layers.append(transformer_block)
397446

398447
return Transformer(model_args, layers, rope)

examples/models/llama/model_args.py

Lines changed: 10 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -102,6 +102,7 @@ class ModelArgs:
102102
apply_output: bool = True # Use output layer (unembedding) inside the transformer
103103
use_qk_norm: bool = False # apply normalization to q and k in the attention
104104
qk_norm_before_rope: bool = False # when to apply qk norm
105+
qk_norm_affine: bool = True # whether QK norm has learnable weight (False = scaleless)
105106
residual_multiplier: Optional[float] = (
106107
None # Scaling factor applied to the residual hidden states
107108
)
@@ -162,6 +163,15 @@ class ModelArgs:
162163
final_logit_softcapping: Optional[float] = None
163164
attn_logit_softcapping: Optional[float] = None
164165

166+
# rlformers forward-pass features for on-device model parity
167+
normalize_tok_embeddings: bool = False
168+
scale_query_by: float = 1.0
169+
use_attn_o_gate: bool = False
170+
use_attn_o_norm: bool = False
171+
use_residual_gate: bool = False
172+
use_ffn_learnable_scales: bool = False
173+
output_soft_cap_temp: Optional[float] = None
174+
165175
def __post_init__(self): # noqa: C901
166176
if self.n_kv_heads is None:
167177
self.n_kv_heads = self.n_heads

examples/models/llama/norm.py

Lines changed: 22 additions & 19 deletions
Original file line number · Diff line number · Diff line change
@@ -32,33 +32,36 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
3232
self.weight = nn.Parameter(torch.ones(dim))
3333

3434
def _norm(self, x):
    """Apply scale-free RMS normalization: x * rsqrt(mean(x^2) + eps)."""
    return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
    """Normalize in fp32, cast back, then apply the learned scale.

    When ``add_unit_offset`` is set the stored weight parameterizes
    ``(1 + weight)``, i.e. a zero-initialized weight is the identity scale.
    """
    normed = self._norm(x.float()).type_as(x)
    if self.add_unit_offset:
        return normed * (1.0 + self.weight.float()).type_as(x)
    return normed * self.weight.type_as(x)
44+
class ScalelessRMSNorm(torch.nn.Module):
    """RMSNorm without a learnable scale: ``x / sqrt(mean(x^2) + eps)``.

    Computation runs in fp32 and is cast back to the input dtype.

    Args:
        dim: normalized dimension size (stored for introspection; the norm
            itself only reduces over the last axis).
        eps: numerical-stability epsilon added to the mean square.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        xf = x.float()
        inv_rms = torch.rsqrt((xf * xf).mean(-1, keepdim=True) + self.eps)
        return (xf * inv_rms).type_as(x)
5053

51-
Args:
52-
x (torch.Tensor): The input tensor.
5354

54-
Returns:
55-
torch.Tensor: The output tensor after applying RMSNorm.
55+
class RMSNormWithInputScale(torch.nn.Module):
    """RMSNorm variant that scales the input *before* normalizing.

    Computes ``rms_norm(weight * x)`` instead of the conventional
    ``weight * rms_norm(x)``; with this ordering the learned ``weight``
    reshapes the per-dimension balance while the output RMS stays ~1.
    Scaling and normalization run in fp32, cast back to the input dtype.

    Args:
        dim: normalized dimension size (also the weight length).
        eps: numerical-stability epsilon added to the mean square.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        y = (self.weight * x).float()
        inv_rms = torch.rsqrt((y * y).mean(-1, keepdim=True) + self.eps)
        return (y * inv_rms).type_as(x)
6265

6366

6467
class RMSNormGated(nn.Module):

0 commit comments

Comments
 (0)