
Commit 5c8ace1

ifed-ucsd authored and facebook-github-bot committed
Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity
Summary:
The 730M dense model checkpoint uses several features that the ExecuTorch XNNPACK export path did not implement. Without these, the exported model produces numerically incorrect output. This diff adds support for 8 missing features:

1. `normalize_tok_embeddings` — scaleless RMSNorm after embedding lookup
2. `qk_norm_before_rope` — conversion from GenAI args (the attention code already supported it)
3. `scale_query_by` — custom scalar multiplier on Q, applied after QK norm
4. `use_attn_o_gate` — sigmoid gate on the attention output, using a learned linear projection of the layer input
5. `use_attn_o_norm` — scaleless per-head RMSNorm on the attention output (applied before the o_gate)
6. `use_residual_gate` — NormPreservingResidualConnection with learned per-dim gates for both the attention and FFN residual connections
7. `use_ffn_learnable_scales` — RMSNormWithInputScale replacing the standard post-FFN norm, computing `rms_norm(gamma * x)` instead of `gamma * rms_norm(x)`
8. `output_soft_cap_temp` — `tanh(logits / temp) * temp` soft capping on the output logits

All features are off by default (backward compatible). They activate when the corresponding fields are set in the checkpoint's params.json and propagated through model_args_conversion.

Weight key mappings added for: `attention.og.weight`, `add_attn.gate`, `add_ffn.gate`, `post_ffn_norm.weight`.

Reviewed By: chinnadhurai, digantdesai

Differential Revision: D102030169
1 parent 321c029 commit 5c8ace1

5 files changed

Lines changed: 201 additions & 53 deletions
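
Before the per-file diffs, a minimal standalone sketch (my own illustration, not code from this commit) of the two formulas quoted in the summary: the `rms_norm(gamma * x)` post-FFN norm behind `use_ffn_learnable_scales`, and the `tanh(logits / temp) * temp` soft cap behind `output_soft_cap_temp`. The function names below are hypothetical; the commit's own implementation of the first is the RMSNormWithInputScale class added in norm.py.

import torch
import torch.nn.functional as F

def rms_norm_with_input_scale(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Feature 7: scale the input first, then normalize -> rms_norm(gamma * x).
    # This differs from gamma * rms_norm(x) because the per-dim scale changes
    # the RMS statistic that the normalization divides by.
    return F.rms_norm(gamma * x, (x.shape[-1],), None, eps)

def soft_cap_logits(logits: torch.Tensor, temp: float) -> torch.Tensor:
    # Feature 8: smoothly bound the output logits to the interval (-temp, temp).
    return torch.tanh(logits / temp) * temp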


examples/models/llama/attention.py

Lines changed: 49 additions & 12 deletions
@@ -7,7 +7,11 @@
 import torch.nn.functional as F
 from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
+from executorch.examples.models.llama.norm import (
+    RMSNorm,
+    RMSNormGated,
+    ScalelessRMSNorm,
+)
 from executorch.examples.models.llama.rope import Rope


@@ -375,6 +379,9 @@ def __init__(
         self.qk_norm_before_rope = args.qk_norm_before_rope
         self.use_q_gate = args.use_q_gate
         self.enable_dynamic_shape = args.enable_dynamic_shape
+        self.scale_query_by = args.scale_query_by
+        self.use_attn_o_gate = args.use_attn_o_gate
+        self.use_attn_o_norm = args.use_attn_o_norm
         q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)

         # YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
@@ -417,17 +424,26 @@ def __init__(
     def _init_norms(self, args: ModelArgs) -> None:
         """Initialize QK normalization layers."""
         if self.use_qk_norm:
-            self.q_norm_fn = RMSNorm(
-                self.head_dim,
-                eps=args.norm_eps,
-                add_unit_offset=args.rms_norm_add_unit_offset,
-            )
-            if self.has_kv_weights:
-                self.k_norm_fn = RMSNorm(
+            if args.qk_norm_affine:
+                self.q_norm_fn = RMSNorm(
                     self.head_dim,
                     eps=args.norm_eps,
                     add_unit_offset=args.rms_norm_add_unit_offset,
                 )
+                if self.has_kv_weights:
+                    self.k_norm_fn = RMSNorm(
+                        self.head_dim,
+                        eps=args.norm_eps,
+                        add_unit_offset=args.rms_norm_add_unit_offset,
+                    )
+            else:
+                self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+                if self.has_kv_weights:
+                    self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_norm:
+            self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_gate:
+            self.og = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)

     def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
         """Initialize Q/K/V/O projection layers."""
@@ -478,13 +494,17 @@ def _prepare_qkv_shared(

         if self.use_qk_norm and self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by

         # Apply RoPE to Q only (K already has RoPE from donor layer)
         q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
         q = q.transpose(1, 2)

         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by

         return q, k, v

@@ -508,6 +528,8 @@ def _prepare_qkv(

         if self.use_qk_norm and self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by
             k = self.k_norm_fn(k)

         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
@@ -518,6 +540,8 @@ def _prepare_qkv(

         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by
             k = self.k_norm_fn(k)

         return q, k, v
@@ -582,8 +606,7 @@ def forward(
         )

         output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

         if shared_kv is None and self.num_kv_shared_layers > 0:
             update = {"kv_to_share": (k, v)}
@@ -602,13 +625,27 @@ def forward(
         output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

         output = self.wo(output)

         return output, None

+    def _apply_output_transforms(
+        self, output: torch.Tensor, x: torch.Tensor, gate, bsz: int, seqlen: int
+    ) -> torch.Tensor:
+        if self.use_attn_o_norm or self.use_attn_o_gate:
+            output_4d = output.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            if self.use_attn_o_norm:
+                output_4d = self.o_norm(output_4d)
+            if self.use_attn_o_gate:
+                og = self.og(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+                output_4d = torch.sigmoid(og) * output_4d
+            output = output_4d.reshape(bsz, seqlen, -1)
+        if gate is not None:
+            output = output * torch.sigmoid(gate)
+        return output
+

 def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
     inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
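
Reading note on `_apply_output_transforms` (my summary, not part of the diff): per head h, the output becomes sigmoid(og(x))_h * ScalelessRMSNorm(output_h) — the per-head norm runs before the sigmoid gate, matching the summary — and only after the heads are flattened do the pre-existing caller-supplied `gate` and the `wo` projection apply, unchanged.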

examples/models/llama/llama_transformer.py

Lines changed: 77 additions & 9 deletions
@@ -7,6 +7,7 @@

 # Please refer to README.md in the same folder for more information.

+import math
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
@@ -20,7 +21,11 @@
 )
 from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm
+from executorch.examples.models.llama.norm import (
+    RMSNorm,
+    RMSNormWithInputScale,
+    ScalelessRMSNorm,
+)
 from executorch.examples.models.llama.rope import Rope
 from torch import nn

@@ -51,6 +56,26 @@ def _is_kv_shared_layer(
     return layer_idx >= first_shared and first_shared > 0


+class NormPreservingResidualConnection(nn.Module):
+    def __init__(
+        self, dim: int, init_scale: float, temperature: float = 0.3, eps: float = 1e-3
+    ):
+        super().__init__()
+        self.eps = eps
+        self.temperature = temperature
+        p = max(0.0 + eps, min(1.0 - eps, init_scale))
+        init_param = math.log(p / (1.0 - p)) * temperature
+        self.gate = nn.Parameter(torch.full((dim,), init_param))
+
+    def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
+        dtype = stream.dtype
+        w = self.gate.view(*([1] * (stream.ndim - 1)), -1).float()
+        beta = torch.sigmoid(w / self.temperature)
+        alpha_sq = torch.sigmoid(-w / self.temperature) * (1.0 + beta)
+        alpha = torch.sqrt(torch.clamp(alpha_sq, min=self.eps))
+        return (alpha * stream.float() + beta * branch.float()).to(dtype)
+
+
 class ConditionalFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -99,7 +124,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

 class TransformerBlock(nn.Module):
     def __init__(
-        self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
+        self,
+        args: ModelArgs,
+        attention: Attention,
+        mlp_type: str = "default",
+        layer_id: int = 0,
     ):
         """
         Transformer block with support for pre-norm and post-norm.
@@ -110,6 +139,7 @@ def __init__(
                 the attention type is registered in the ATTENTION_REGISTRY.
             mlp_type (str): MLP type for this layer. "default" for standard
                 FFN, "skip" for no FFN block.
+            layer_id (int): layer index, used for residual gate initialization.
         """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
@@ -118,6 +148,7 @@ def __init__(
         self.head_dim = args.head_dim
         self.attention = attention
         self.mlp_type = mlp_type.lower()
+        self.use_residual_gate = args.use_residual_gate

         assert (
             args.hidden_dim is not None
@@ -150,6 +181,20 @@ def __init__(
                 add_unit_offset=args.rms_norm_add_unit_offset,
             )

+        if args.use_residual_gate:
+            attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
+            ffn_init = 1.0 / (2 * layer_id + 2)
+            self.add_attn = NormPreservingResidualConnection(
+                dim=args.dim, init_scale=attn_init
+            )
+            self.add_ffn = NormPreservingResidualConnection(
+                dim=args.dim, init_scale=ffn_init
+            )
+            self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps)
+
+        if args.use_ffn_learnable_scales and self.mlp_type != "skip":
+            self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps)
+
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         """
@@ -169,21 +214,38 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         mlp_type = args.mlp_type[layer_id]
         cls = ATTENTION_REGISTRY[args.attention_type]
         attention = cls(args, layer_id, rope, **args.attention_kwargs)
-        return TransformerBlock(args, attention, mlp_type=mlp_type)
+        return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id)

     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
         if not isinstance(self.attention, AttentionSkip):
-            h = x + h
+            if self.use_residual_gate:
+                if hasattr(self, "post_attn_norm"):
+                    h = self.post_attn_norm(h)
+                h = self.add_attn(stream=x, branch=h)
+            else:
+                h = x + h

         if self.mlp_type == "skip":
             out = h
         elif hasattr(self, "block_sparse_moe"):
-            out = h + self.block_sparse_moe(self.ffn_norm(h))
+            ffn_out = self.block_sparse_moe(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
         else:
-            out = h + self.feed_forward(self.ffn_norm(h))
+            ffn_out = self.feed_forward(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
         return out, attn_options_update


@@ -371,7 +433,9 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             and model_args.layer_types[layer_id] == "skip_attention"
         ):
             attention = AttentionSkip()
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(
+                model_args, attention, layer_id=layer_id
+            )
             layers.append(transformer_block)
         elif (
             model_args.layer_types
@@ -386,13 +450,17 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             attention = linear_cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
             )
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(
+                model_args, attention, layer_id=layer_id
+            )
             layers.append(transformer_block)
         else:
             attention = cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
             )  # pyre-ignore[45]
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(
+                model_args, attention, layer_id=layer_id
+            )
             layers.append(transformer_block)

     return Transformer(model_args, layers, rope)
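
A quick check of the NormPreservingResidualConnection initialization above (my own sketch, not part of the diff): the gate is initialized to logit(init_scale) * temperature, so at initialization beta = sigmoid(gate / temperature) = init_scale and alpha = sqrt((1 - init_scale) * (1 + init_scale)) = sqrt(1 - init_scale**2), hence alpha**2 + beta**2 = 1 — consistent with the "norm preserving" name (exact norm preservation at init when stream and branch are orthogonal with equal norm).

import math

def init_mix_coefficients(init_scale: float, temperature: float = 0.3, eps: float = 1e-3):
    # Mirrors the module's __init__ plus its forward() at initialization time.
    p = max(eps, min(1.0 - eps, init_scale))
    gate = math.log(p / (1.0 - p)) * temperature
    beta = 1.0 / (1.0 + math.exp(-gate / temperature))              # sigmoid -> p
    alpha_sq = (1.0 / (1.0 + math.exp(gate / temperature))) * (1.0 + beta)
    alpha = math.sqrt(max(alpha_sq, eps))
    return alpha, beta

# Layer 0 uses attn_init = 0.5, giving beta = 0.5 and alpha = sqrt(0.75) ~ 0.866.
print(init_mix_coefficients(0.5))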

examples/models/llama/model_args.py

Lines changed: 12 additions & 0 deletions
@@ -102,6 +102,9 @@ class ModelArgs:
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
+    qk_norm_affine: bool = (
+        True  # whether QK norm has learnable weight (False = scaleless)
+    )
     residual_multiplier: Optional[float] = (
         None  # Scaling factor applied to the residual hidden states
     )
@@ -162,6 +165,15 @@ class ModelArgs:
     final_logit_softcapping: Optional[float] = None
     attn_logit_softcapping: Optional[float] = None

+    # rlformers forward-pass features for on-device model parity
+    normalize_tok_embeddings: bool = False
+    scale_query_by: float = 1.0
+    use_attn_o_gate: bool = False
+    use_attn_o_norm: bool = False
+    use_residual_gate: bool = False
+    use_ffn_learnable_scales: bool = False
+    output_soft_cap_temp: Optional[float] = None
+
     def __post_init__(self):  # noqa: C901
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
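
For context, a hedged sketch of how the new fields would be enabled. The values below are hypothetical and purely illustrative; in practice they come from the checkpoint's params.json via model_args_conversion, and this assumes the remaining ModelArgs fields keep their defaults.

from executorch.examples.models.llama.model_args import ModelArgs

# Hypothetical values -- the real ones are read from params.json.
args = ModelArgs(
    use_qk_norm=True,
    qk_norm_before_rope=True,
    qk_norm_affine=False,            # scaleless QK norm
    normalize_tok_embeddings=True,
    scale_query_by=2.0,              # hypothetical scalar
    use_attn_o_gate=True,
    use_attn_o_norm=True,
    use_residual_gate=True,
    use_ffn_learnable_scales=True,
    output_soft_cap_temp=30.0,       # hypothetical temperature
)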

examples/models/llama/norm.py

Lines changed: 24 additions & 19 deletions
@@ -32,33 +32,38 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
         self.weight = nn.Parameter(torch.ones(dim))

     def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

-        Args:
-            x (torch.Tensor): The input tensor.
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        if self.add_unit_offset:
+            return output * (1.0 + self.weight.float()).type_as(x)
+        return output * self.weight.type_as(x)

-        Returns:
-            torch.Tensor: The normalized tensor.

-        """
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+class ScalelessRMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.dim = dim
+        self.eps = eps

     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
+        x_float = x.float()
+        return (
+            x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
+        ).type_as(x)

-        Args:
-            x (torch.Tensor): The input tensor.

-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
+class RMSNormWithInputScale(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.dim = dim
+        self.weight = torch.nn.Parameter(torch.ones(dim))

-        """
-        output = self._norm(x.float()).type_as(x)
-        if self.add_unit_offset:
-            return output * (1.0 + self.weight.float()).type_as(x)
-        return output * self.weight.type_as(x)
+    def forward(self, x):
+        scaled = self.weight * x
+        return F.rms_norm(scaled, (self.dim,), None, self.eps)


 class RMSNormGated(nn.Module):
