
Commit e686bed

ifed-ucsd authored and facebook-github-bot committed
Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity
Summary:
The 730M dense model checkpoint (16node_IFT_avoDistilled730m) uses several rlformers features that the ExecuTorch XNNPACK export path did not implement. Without these, the exported model produces numerically incorrect output. This diff adds support for 8 missing features:

1. `normalize_tok_embeddings` — scaleless RMSNorm after embedding lookup
2. `qk_norm_before_rope` — conversion from GenAI args (attention code already supported it)
3. `scale_query_by` — custom scalar multiplier on Q after QK norm
4. `use_attn_o_gate` — sigmoid gate on attention output using a learned linear projection of the layer input
5. `use_attn_o_norm` — scaleless per-head RMSNorm on attention output (applied before o_gate)
6. `use_residual_gate` — NormPreservingResidualConnection with learned per-dim gates for both attention and FFN residual connections
7. `use_ffn_learnable_scales` — RMSNormWithInputScale replacing standard post-FFN norm, computing `rms_norm(gamma * x)` instead of `gamma * rms_norm(x)`
8. `output_soft_cap_temp` — `tanh(logits/temp) * temp` soft capping on output logits

All features are off by default (backward compatible). They activate when the corresponding fields are set in the checkpoint's params.json and propagated through model_args_conversion.

Weight key mappings added for: `attention.og.weight`, `add_attn.gate`, `add_ffn.gate`, `post_ffn_norm.weight`.

Reviewed By: chinnadhurai, digantdesai

Differential Revision: D102030169
1 parent d0b7934 commit e686bed

5 files changed

Lines changed: 168 additions & 57 deletions
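Of the eight features, `normalize_tok_embeddings` and `output_soft_cap_temp` are applied outside the four files shown below (at the embedding lookup and on the output logits). As orientation only, here is a minimal sketch of the two transforms as described in the summary; the function names and call sites are illustrative, not the actual ExecuTorch code:

```python
import torch


def normalize_tok_embeddings(emb: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scaleless RMSNorm over the embedding dimension, applied right after the lookup.
    emb_f = emb.float()
    return (emb_f * torch.rsqrt((emb_f * emb_f).mean(-1, keepdim=True) + eps)).type_as(emb)


def soft_cap_logits(logits: torch.Tensor, temp: float) -> torch.Tensor:
    # output_soft_cap_temp: tanh(logits / temp) * temp keeps logits within (-temp, temp).
    return torch.tanh(logits / temp) * temp
```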

File tree

examples/models/llama/attention.py
examples/models/llama/llama_transformer.py
examples/models/llama/model_args.py
examples/models/llama/norm.py

examples/models/llama/attention.py

Lines changed: 41 additions & 16 deletions
@@ -8,7 +8,7 @@
 import torch.nn.functional as F
 from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
+from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated, ScalelessRMSNorm
 from executorch.examples.models.llama.rope import Rope
 
 
@@ -410,6 +410,9 @@ def __init__(
         self.qk_norm_before_rope = args.qk_norm_before_rope
         self.use_q_gate = args.use_q_gate
         self.enable_dynamic_shape = args.enable_dynamic_shape
+        self.scale_query_by = args.scale_query_by
+        self.use_attn_o_gate = args.use_attn_o_gate
+        self.use_attn_o_norm = args.use_attn_o_norm
         q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)
 
         # YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
@@ -452,17 +455,26 @@ def __init__(
     def _init_norms(self, args: ModelArgs) -> None:
         """Initialize QK normalization layers."""
         if self.use_qk_norm:
-            self.q_norm_fn = RMSNorm(
-                self.head_dim,
-                eps=args.norm_eps,
-                add_unit_offset=args.rms_norm_add_unit_offset,
-            )
-            if self.has_kv_weights:
-                self.k_norm_fn = RMSNorm(
+            if args.qk_norm_affine:
+                self.q_norm_fn = RMSNorm(
                     self.head_dim,
                     eps=args.norm_eps,
                     add_unit_offset=args.rms_norm_add_unit_offset,
                 )
+                if self.has_kv_weights:
+                    self.k_norm_fn = RMSNorm(
+                        self.head_dim,
+                        eps=args.norm_eps,
+                        add_unit_offset=args.rms_norm_add_unit_offset,
+                    )
+            else:
+                self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+                if self.has_kv_weights:
+                    self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_norm:
+            self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_gate:
+            self.og = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)
 
     def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
         """Initialize Q/K/V/O projection layers."""
@@ -512,14 +524,14 @@ def _prepare_qkv_shared(
         k, v = shared_kv
 
         if self.use_qk_norm and self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by
 
         # Apply RoPE to Q only (K already has RoPE from donor layer)
         q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
         q = q.transpose(1, 2)
 
         if self.use_qk_norm and not self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by
 
         return q, k, v
 
@@ -542,7 +554,7 @@ def _prepare_qkv(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
         if self.use_qk_norm and self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by
             k = self.k_norm_fn(k)
 
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
@@ -552,7 +564,7 @@ def _prepare_qkv(
         v = v.transpose(1, 2)
 
         if self.use_qk_norm and not self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by
             k = self.k_norm_fn(k)
 
         return q, k, v
@@ -617,8 +629,7 @@ def forward(
         )
 
         output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)
 
         if shared_kv is None and self.num_kv_shared_layers > 0:
             update = {"kv_to_share": (k, v)}
@@ -637,13 +648,27 @@ def forward(
         output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
 
         output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)
 
         output = self.wo(output)
 
         return output, None
 
+    def _apply_output_transforms(
+        self, output: torch.Tensor, x: torch.Tensor, gate, bsz: int, seqlen: int
+    ) -> torch.Tensor:
+        if self.use_attn_o_norm or self.use_attn_o_gate:
+            output_4d = output.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            if self.use_attn_o_norm:
+                output_4d = self.o_norm(output_4d)
+            if self.use_attn_o_gate:
+                og = self.og(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+                output_4d = torch.sigmoid(og) * output_4d
+            output = output_4d.reshape(bsz, seqlen, -1)
+        if gate is not None:
+            output = output * torch.sigmoid(gate)
+        return output
+
 
 def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
     inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
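The new `_apply_output_transforms` path reads as: reshape the SDPA output to per-head form, optionally apply a scaleless per-head RMSNorm (`use_attn_o_norm`), optionally multiply by a sigmoid gate computed from the attention layer's input via the `og` projection (`use_attn_o_gate`), then flatten the heads again before `wo`. A standalone sketch of that math with illustrative shapes (not the exported module itself):

```python
import torch

bsz, seqlen, n_heads, head_dim, dim = 1, 4, 8, 64, 512
attn_out = torch.randn(bsz, seqlen, n_heads * head_dim)   # SDPA output, heads flattened
x = torch.randn(bsz, seqlen, dim)                          # input to the attention layer
og = torch.nn.Linear(dim, n_heads * head_dim, bias=False)  # checkpoint key: attention.og.weight

out = attn_out.view(bsz, seqlen, n_heads, head_dim)
# use_attn_o_norm: scaleless RMSNorm per head, applied before the gate
out = out * torch.rsqrt((out * out).mean(-1, keepdim=True) + 1e-6)
# use_attn_o_gate: sigmoid of a learned projection of the layer input
gate = torch.sigmoid(og(x).view(bsz, seqlen, n_heads, head_dim))
out = (gate * out).reshape(bsz, seqlen, -1)  # this then goes through self.wo
```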

examples/models/llama/llama_transformer.py

Lines changed: 58 additions & 9 deletions
@@ -7,6 +7,7 @@
 
 # Please refer to README.md in the same folder for more information.
 
+import math
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
@@ -20,7 +21,7 @@
 )
 from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm
+from executorch.examples.models.llama.norm import RMSNorm, RMSNormWithInputScale, ScalelessRMSNorm
 from executorch.examples.models.llama.rope import Rope
 from torch import nn
 
@@ -51,6 +52,24 @@ def _is_kv_shared_layer(
     return layer_idx >= first_shared and first_shared > 0
 
 
+class NormPreservingResidualConnection(nn.Module):
+    def __init__(self, dim: int, init_scale: float, temperature: float = 0.3, eps: float = 1e-3):
+        super().__init__()
+        self.eps = eps
+        self.temperature = temperature
+        p = max(0.0 + eps, min(1.0 - eps, init_scale))
+        init_param = math.log(p / (1.0 - p)) * temperature
+        self.gate = nn.Parameter(torch.full((dim,), init_param))
+
+    def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
+        dtype = stream.dtype
+        w = self.gate.view(*([1] * (stream.ndim - 1)), -1).float()
+        beta = torch.sigmoid(w / self.temperature)
+        alpha_sq = torch.sigmoid(-w / self.temperature) * (1.0 + beta)
+        alpha = torch.sqrt(torch.clamp(alpha_sq, min=self.eps))
+        return (alpha * stream.float() + beta * branch.float()).to(dtype)
+
+
 class ConditionalFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -99,7 +118,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class TransformerBlock(nn.Module):
     def __init__(
-        self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
+        self, args: ModelArgs, attention: Attention, mlp_type: str = "default",
+        layer_id: int = 0,
     ):
         """
         Transformer block with support for pre-norm and post-norm.
@@ -110,6 +130,7 @@ def __init__(
                 the attention type is registered in the ATTENTION_REGISTRY.
            mlp_type (str): MLP type for this layer. "default" for standard
                FFN, "skip" for no FFN block.
+            layer_id (int): layer index, used for residual gate initialization.
         """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
@@ -118,6 +139,7 @@ def __init__(
         self.head_dim = args.head_dim
         self.attention = attention
         self.mlp_type = mlp_type.lower()
+        self.use_residual_gate = args.use_residual_gate
 
         assert (
             args.hidden_dim is not None
@@ -150,6 +172,16 @@ def __init__(
             add_unit_offset=args.rms_norm_add_unit_offset,
         )
 
+        if args.use_residual_gate:
+            attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
+            ffn_init = 1.0 / (2 * layer_id + 2)
+            self.add_attn = NormPreservingResidualConnection(dim=args.dim, init_scale=attn_init)
+            self.add_ffn = NormPreservingResidualConnection(dim=args.dim, init_scale=ffn_init)
+            self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps)
+
+        if args.use_ffn_learnable_scales and self.mlp_type != "skip":
+            self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps)
+
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         """
@@ -169,21 +201,38 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         mlp_type = args.mlp_type[layer_id]
         cls = ATTENTION_REGISTRY[args.attention_type]
         attention = cls(args, layer_id, rope, **args.attention_kwargs)
-        return TransformerBlock(args, attention, mlp_type=mlp_type)
+        return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id)
 
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN
         h, attn_options_update = self.attention(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
         if not isinstance(self.attention, AttentionSkip):
-            h = x + h
+            if self.use_residual_gate:
+                if hasattr(self, "post_attn_norm"):
+                    h = self.post_attn_norm(h)
+                h = self.add_attn(stream=x, branch=h)
+            else:
+                h = x + h
 
         if self.mlp_type == "skip":
             out = h
         elif hasattr(self, "block_sparse_moe"):
-            out = h + self.block_sparse_moe(self.ffn_norm(h))
+            ffn_out = self.block_sparse_moe(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
         else:
-            out = h + self.feed_forward(self.ffn_norm(h))
+            ffn_out = self.feed_forward(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
         return out, attn_options_update
 
 
@@ -371,7 +420,7 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             and model_args.layer_types[layer_id] == "skip_attention"
         ):
             attention = AttentionSkip()
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
             layers.append(transformer_block)
         elif (
             model_args.layer_types
@@ -386,13 +435,13 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             attention = linear_cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
             )
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
             layers.append(transformer_block)
         else:
            attention = cls(
                model_args, layer_id, rope, **model_args.attention_kwargs
            ) # pyre-ignore[45]
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
             layers.append(transformer_block)
 
     return Transformer(model_args, layers, rope)
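A note on why `NormPreservingResidualConnection` is "norm preserving": with `beta = sigmoid(w / T)` and the identity `sigmoid(-z) = 1 - sigmoid(z)`, the line `alpha_sq = sigmoid(-w/T) * (1 + beta)` equals `(1 - beta) * (1 + beta) = 1 - beta**2`, so `alpha**2 + beta**2 = 1` per dimension (up to the clamp). `init_scale` is exactly the initial value of `beta`, which is why the block uses `1/(2*layer_id+1)` for the attention branch and `1/(2*layer_id+2)` for the FFN branch. A quick sanity check of that algebra, assuming the class is importable from this module after the change:

```python
import torch

from executorch.examples.models.llama.llama_transformer import (
    NormPreservingResidualConnection,
)

gate = NormPreservingResidualConnection(dim=512, init_scale=0.25)
w = gate.gate.float()
beta = torch.sigmoid(w / gate.temperature)
alpha_sq = torch.sigmoid(-w / gate.temperature) * (1.0 + beta)

print(torch.allclose(beta, torch.full_like(beta, 0.25), atol=1e-6))              # init_scale sets beta
print(torch.allclose(alpha_sq + beta * beta, torch.ones_like(beta), atol=1e-6))  # alpha^2 + beta^2 = 1
```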

examples/models/llama/model_args.py

Lines changed: 10 additions & 0 deletions
@@ -102,6 +102,7 @@ class ModelArgs:
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
+    qk_norm_affine: bool = True  # whether QK norm has learnable weight (False = scaleless)
     residual_multiplier: Optional[float] = (
         None  # Scaling factor applied to the residual hidden states
     )
@@ -162,6 +163,15 @@ class ModelArgs:
     final_logit_softcapping: Optional[float] = None
     attn_logit_softcapping: Optional[float] = None
 
+    # rlformers forward-pass features for on-device model parity
+    normalize_tok_embeddings: bool = False
+    scale_query_by: float = 1.0
+    use_attn_o_gate: bool = False
+    use_attn_o_norm: bool = False
+    use_residual_gate: bool = False
+    use_ffn_learnable_scales: bool = False
+    output_soft_cap_temp: Optional[float] = None
+
     def __post_init__(self):  # noqa: C901
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
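For illustration only, enabling the new behavior amounts to setting the added fields on ModelArgs; in the real flow they are read from the checkpoint's params.json and propagated through model_args_conversion (not shown in this excerpt), and the values below are hypothetical:

```python
from executorch.examples.models.llama.model_args import ModelArgs

# Hypothetical values; every new field defaults to "off", so omitting them
# leaves existing checkpoints and exports unchanged.
args = ModelArgs(
    use_qk_norm=True,
    qk_norm_before_rope=True,
    qk_norm_affine=False,            # scaleless QK norm
    normalize_tok_embeddings=True,
    scale_query_by=1.5,
    use_attn_o_gate=True,
    use_attn_o_norm=True,
    use_residual_gate=True,
    use_ffn_learnable_scales=True,
    output_soft_cap_temp=30.0,
)
```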

examples/models/llama/norm.py

Lines changed: 22 additions & 19 deletions
@@ -32,33 +32,36 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
         self.weight = nn.Parameter(torch.ones(dim))
 
     def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
 
-        Args:
-            x (torch.Tensor): The input tensor.
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        if self.add_unit_offset:
+            return output * (1.0 + self.weight.float()).type_as(x)
+        return output * self.weight.type_as(x)
 
-        Returns:
-            torch.Tensor: The normalized tensor.
 
-        """
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+class ScalelessRMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
 
     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
+        x_float = x.float()
+        return (x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)).type_as(x)
 
-        Args:
-            x (torch.Tensor): The input tensor.
 
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
+class RMSNormWithInputScale(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.dim = dim
+        self.weight = torch.nn.Parameter(torch.ones(dim))
 
-        """
-        output = self._norm(x.float()).type_as(x)
-        if self.add_unit_offset:
-            return output * (1.0 + self.weight.float()).type_as(x)
-        return output * self.weight.type_as(x)
+    def forward(self, x):
+        scaled = (self.weight * x).float()
+        return (scaled * torch.rsqrt((scaled * scaled).mean(-1, keepdim=True) + self.eps)).type_as(x)
 
 
 class RMSNormGated(nn.Module):
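The practical difference between the two norms is only where the learned weight enters: `RMSNorm` computes `gamma * rms_norm(x)`, while `RMSNormWithInputScale` computes `rms_norm(gamma * x)`, so the weight redistributes energy across dimensions but the output RMS stays near 1. A small check of that contrast, assuming both classes import from this module after the change:

```python
import torch

from executorch.examples.models.llama.norm import RMSNorm, RMSNormWithInputScale

torch.manual_seed(0)
x = torch.randn(2, 8, 64)

post_scale = RMSNorm(64)                 # gamma * rms_norm(x)
input_scale = RMSNormWithInputScale(64)  # rms_norm(gamma * x)
with torch.no_grad():
    post_scale.weight.mul_(3.0)
    input_scale.weight.mul_(3.0)

rms = lambda t: t.pow(2).mean(-1).sqrt().mean()
print(rms(post_scale(x)))   # ~3.0: the weight rescales the normalized output
print(rms(input_scale(x)))  # ~1.0: normalization happens after the scale
```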
