
Commit 7feccab

ifed-ucsd authored and facebook-github-bot committed
Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity
Summary:

The 730M dense model checkpoint (16node_IFT_avoDistilled730m) uses several rlformers features that the ExecuTorch XNNPACK export path did not implement. Without these, the exported model produces numerically incorrect output. This diff adds support for 8 missing features:

1. `normalize_tok_embeddings` — scaleless RMSNorm after embedding lookup
2. `qk_norm_before_rope` — conversion from GenAI args (the attention code already supported it)
3. `scale_query_by` — custom scalar multiplier on Q after QK norm
4. `use_attn_o_gate` — sigmoid gate on the attention output, using a learned linear projection of the layer input
5. `use_attn_o_norm` — scaleless per-head RMSNorm on the attention output (applied before the o_gate)
6. `use_residual_gate` — NormPreservingResidualConnection with learned per-dim gates for both the attention and FFN residual connections
7. `use_ffn_learnable_scales` — RMSNormWithInputScale replacing the standard post-FFN norm, computing `rms_norm(gamma * x)` instead of `gamma * rms_norm(x)`
8. `output_soft_cap_temp` — `tanh(logits / temp) * temp` soft capping on the output logits

All features are off by default (backward compatible). They activate when the corresponding fields are set in the checkpoint's params.json and propagated through model_args_conversion.

Weight key mappings added for: `attention.og.weight`, `add_attn.gate`, `add_ffn.gate`, `post_ffn_norm.weight`.

Differential Revision: D102030169
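Items 1 and 8 above (`normalize_tok_embeddings` and `output_soft_cap_temp`) are applied in parts of the model that the hunks below do not show. As a rough, hedged illustration of what the summary describes — not the actual ExecuTorch code; tensor shapes and the `eps` value are assumptions — the two transforms behave like this:

```python
import torch

def soft_cap_logits(logits: torch.Tensor, temp: float) -> torch.Tensor:
    # output_soft_cap_temp: squash logits smoothly into (-temp, temp).
    return torch.tanh(logits / temp) * temp

def normalize_embeddings(emb: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # normalize_tok_embeddings: scaleless RMSNorm over the embedding dim,
    # i.e. RMSNorm without a learnable weight.
    return emb * torch.rsqrt((emb * emb).mean(-1, keepdim=True) + eps)

# Toy check: a batch of large logits stays bounded by the cap temperature.
logits = torch.randn(1, 8, 32000) * 50.0
capped = soft_cap_logits(logits, temp=30.0)
assert capped.abs().max() < 30.0
```

Soft capping bounds every logit to (-temp, temp) while remaining smooth, so it changes the numerics without introducing hard clipping.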
1 parent 7b5dcc1 commit 7feccab

4 files changed

Lines changed: 129 additions & 44 deletions

File tree

examples/models/llama/attention.py
examples/models/llama/llama_transformer.py
examples/models/llama/model_args.py
examples/models/llama/norm.py

examples/models/llama/attention.py

Lines changed: 41 additions & 16 deletions
```diff
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
+from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated, ScalelessRMSNorm
 from executorch.examples.models.llama.rope import Rope
 
 
@@ -375,6 +375,9 @@ def __init__(
         self.qk_norm_before_rope = args.qk_norm_before_rope
         self.use_q_gate = args.use_q_gate
         self.enable_dynamic_shape = args.enable_dynamic_shape
+        self.scale_query_by = args.scale_query_by
+        self.use_attn_o_gate = args.use_attn_o_gate
+        self.use_attn_o_norm = args.use_attn_o_norm
         q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)
 
         # YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
@@ -417,17 +420,26 @@ def __init__(
     def _init_norms(self, args: ModelArgs) -> None:
         """Initialize QK normalization layers."""
         if self.use_qk_norm:
-            self.q_norm_fn = RMSNorm(
-                self.head_dim,
-                eps=args.norm_eps,
-                add_unit_offset=args.rms_norm_add_unit_offset,
-            )
-            if self.has_kv_weights:
-                self.k_norm_fn = RMSNorm(
+            if args.qk_norm_affine:
+                self.q_norm_fn = RMSNorm(
                     self.head_dim,
                     eps=args.norm_eps,
                     add_unit_offset=args.rms_norm_add_unit_offset,
                 )
+                if self.has_kv_weights:
+                    self.k_norm_fn = RMSNorm(
+                        self.head_dim,
+                        eps=args.norm_eps,
+                        add_unit_offset=args.rms_norm_add_unit_offset,
+                    )
+            else:
+                self.q_norm_fn = RMSNorm(self.head_dim, eps=args.norm_eps)
+                if self.has_kv_weights:
+                    self.k_norm_fn = RMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_norm:
+            self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_gate:
+            self.og = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)
 
     def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
         """Initialize Q/K/V/O projection layers."""
@@ -477,14 +489,14 @@ def _prepare_qkv_shared(
         k, v = shared_kv
 
         if self.use_qk_norm and self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by
 
         # Apply RoPE to Q only (K already has RoPE from donor layer)
         q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
         q = q.transpose(1, 2)
 
         if self.use_qk_norm and not self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by
 
         return q, k, v
 
@@ -507,7 +519,7 @@ def _prepare_qkv(
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
 
         if self.use_qk_norm and self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by
             k = self.k_norm_fn(k)
 
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
@@ -517,7 +529,7 @@
         v = v.transpose(1, 2)
 
         if self.use_qk_norm and not self.qk_norm_before_rope:
-            q = self.q_norm_fn(q)
+            q = self.q_norm_fn(q) * self.scale_query_by
             k = self.k_norm_fn(k)
 
         return q, k, v
@@ -582,8 +594,7 @@ def forward(
         )
 
         output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)
 
         if shared_kv is None and self.num_kv_shared_layers > 0:
             update = {"kv_to_share": (k, v)}
@@ -602,13 +613,27 @@
        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
 
        output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
-       if gate is not None:
-           output = output * torch.sigmoid(gate)
+       output = self._apply_output_transforms(output, x, gate, bsz, seqlen)
 
        output = self.wo(output)
 
        return output, None
 
+    def _apply_output_transforms(
+        self, output: torch.Tensor, x: torch.Tensor, gate, bsz: int, seqlen: int
+    ) -> torch.Tensor:
+        if self.use_attn_o_norm or self.use_attn_o_gate:
+            output_4d = output.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            if self.use_attn_o_norm:
+                output_4d = self.o_norm(output_4d)
+            if self.use_attn_o_gate:
+                og = self.og(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+                output_4d = torch.sigmoid(og) * output_4d
+            output = output_4d.reshape(bsz, seqlen, -1)
+        if gate is not None:
+            output = output * torch.sigmoid(gate)
+        return output
+
 
 def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
     inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
```
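To make the ordering in `_apply_output_transforms` concrete — per-head scaleless RMSNorm first, then a sigmoid gate computed from the layer input `x` rather than from the attention output — here is a small shape-level sketch. The dimensions are made up for illustration and it mirrors the logic above rather than calling the real module:

```python
import torch
import torch.nn as nn

# Shapes assumed for illustration only.
bsz, seqlen, n_heads, head_dim, dim = 2, 4, 8, 64, 512

x = torch.randn(bsz, seqlen, dim)                        # layer input (post attention_norm)
attn_out = torch.randn(bsz, seqlen, n_heads * head_dim)  # SDPA output, heads already flattened
og_proj = nn.Linear(dim, n_heads * head_dim, bias=False)

# use_attn_o_norm: scaleless RMSNorm per head, so view as (bsz, seq, heads, head_dim).
out4d = attn_out.view(bsz, seqlen, n_heads, head_dim)
out4d = out4d * torch.rsqrt((out4d * out4d).mean(-1, keepdim=True) + 1e-6)

# use_attn_o_gate: sigmoid gate computed from the layer *input* x, not from attn_out.
gate = torch.sigmoid(og_proj(x).view(bsz, seqlen, n_heads, head_dim))
out = (gate * out4d).reshape(bsz, seqlen, -1)            # back to 3D before the wo projection
assert out.shape == (bsz, seqlen, n_heads * head_dim)
```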

examples/models/llama/llama_transformer.py

Lines changed: 57 additions & 9 deletions
```diff
@@ -20,7 +20,7 @@
 )
 from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm
+from executorch.examples.models.llama.norm import RMSNorm, RMSNormWithInputScale, ScalelessRMSNorm
 from executorch.examples.models.llama.rope import Rope
 from torch import nn
 
@@ -51,6 +51,24 @@ def _is_kv_shared_layer(
     return layer_idx >= first_shared and first_shared > 0
 
 
+class NormPreservingResidualConnection(nn.Module):
+    def __init__(self, dim: int, init_scale: float, temperature: float = 0.3, eps: float = 1e-3):
+        super().__init__()
+        import math
+        self.eps = eps
+        self.temperature = temperature
+        p = max(0.0 + eps, min(1.0 - eps, init_scale))
+        init_param = math.log(p / (1.0 - p)) * temperature
+        self.gate = nn.Parameter(torch.full((dim,), init_param))
+
+    def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
+        w = self.gate.view(*([1] * (stream.ndim - 1)), -1)
+        beta = torch.sigmoid(w / self.temperature)
+        alpha_sq = torch.sigmoid(-w / self.temperature) * (1.0 + beta)
+        alpha = torch.sqrt(torch.clamp(alpha_sq, min=self.eps))
+        return alpha * stream + beta * branch
+
+
 class ConditionalFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -99,7 +117,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class TransformerBlock(nn.Module):
     def __init__(
-        self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
+        self, args: ModelArgs, attention: Attention, mlp_type: str = "default",
+        layer_id: int = 0,
     ):
         """
         Transformer block with support for pre-norm and post-norm.
@@ -110,6 +129,7 @@ def __init__(
                 the attention type is registered in the ATTENTION_REGISTRY.
            mlp_type (str): MLP type for this layer. "default" for standard
                 FFN, "skip" for no FFN block.
+           layer_id (int): layer index, used for residual gate initialization.
         """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
@@ -118,6 +138,7 @@ def __init__(
         self.head_dim = args.head_dim
         self.attention = attention
         self.mlp_type = mlp_type.lower()
+        self.use_residual_gate = args.use_residual_gate
 
         assert (
             args.hidden_dim is not None
@@ -150,6 +171,16 @@ def __init__(
                 add_unit_offset=args.rms_norm_add_unit_offset,
             )
 
+        if args.use_residual_gate:
+            attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
+            ffn_init = 1.0 / (2 * layer_id + 2)
+            self.add_attn = NormPreservingResidualConnection(dim=args.dim, init_scale=attn_init)
+            self.add_ffn = NormPreservingResidualConnection(dim=args.dim, init_scale=ffn_init)
+            self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps)
+
+        if args.use_ffn_learnable_scales and self.mlp_type != "skip":
+            self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps)
+
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         """
@@ -169,21 +200,38 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         mlp_type = args.mlp_type[layer_id]
         cls = ATTENTION_REGISTRY[args.attention_type]
         attention = cls(args, layer_id, rope, **args.attention_kwargs)
-        return TransformerBlock(args, attention, mlp_type=mlp_type)
+        return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id)
 
     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
         if not isinstance(self.attention, AttentionSkip):
-            h = x + h
+            if self.use_residual_gate:
+                if hasattr(self, "post_attn_norm"):
+                    h = self.post_attn_norm(h)
+                h = self.add_attn(stream=x, branch=h)
+            else:
+                h = x + h
 
         if self.mlp_type == "skip":
             out = h
         elif hasattr(self, "block_sparse_moe"):
-            out = h + self.block_sparse_moe(self.ffn_norm(h))
+            ffn_out = self.block_sparse_moe(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
         else:
-            out = h + self.feed_forward(self.ffn_norm(h))
+            ffn_out = self.feed_forward(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
         return out, attn_options_update
 
 
@@ -371,7 +419,7 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             and model_args.layer_types[layer_id] == "skip_attention"
         ):
             attention = AttentionSkip()
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
             layers.append(transformer_block)
         elif (
             model_args.layer_types
@@ -386,13 +434,13 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             attention = linear_cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
             )
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
             layers.append(transformer_block)
         else:
             attention = cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
             )  # pyre-ignore[45]
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(model_args, attention, layer_id=layer_id)
             layers.append(transformer_block)
 
     return Transformer(model_args, layers, rope)
```
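A note on why `NormPreservingResidualConnection` is "norm preserving": with `beta = sigmoid(w / T)` and `sigmoid(-w / T) = 1 - beta`, the code sets `alpha_sq = (1 - beta) * (1 + beta) = 1 - beta**2`, so `alpha**2 + beta**2 = 1` per dimension (up to the `eps` clamp), and `init_param = logit(init_scale) * T` makes the branch weight start at `init_scale`. A small standalone check, with values chosen purely for illustration:

```python
import math
import torch

temperature = 0.3
init_scale = 0.25  # matches ffn_init = 1 / (2 * layer_id + 2) for layer_id == 1
w = torch.tensor(math.log(init_scale / (1.0 - init_scale)) * temperature)

beta = torch.sigmoid(w / temperature)                        # starts at init_scale
alpha = torch.sqrt(torch.sigmoid(-w / temperature) * (1.0 + beta))

print(float(beta))                  # ~0.25: the branch weight at initialization
print(float(alpha ** 2 + beta ** 2))  # ~1.0: sigmoid(-z) = 1 - sigmoid(z), so
                                       # alpha^2 = (1 - beta)(1 + beta) = 1 - beta^2
```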

examples/models/llama/model_args.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -102,6 +102,7 @@ class ModelArgs:
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
+    qk_norm_affine: bool = True  # whether QK norm has learnable weight (False = scaleless)
     residual_multiplier: Optional[float] = (
         None  # Scaling factor applied to the residual hidden states
     )
@@ -162,6 +163,15 @@ class ModelArgs:
     final_logit_softcapping: Optional[float] = None
     attn_logit_softcapping: Optional[float] = None
 
+    # rlformers forward-pass features for on-device model parity
+    normalize_tok_embeddings: bool = False
+    scale_query_by: float = 1.0
+    use_attn_o_gate: bool = False
+    use_attn_o_norm: bool = False
+    use_residual_gate: bool = False
+    use_ffn_learnable_scales: bool = False
+    output_soft_cap_temp: Optional[float] = None
+
     def __post_init__(self):  # noqa: C901
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
```
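The new `ModelArgs` fields default to off, so existing exports are unaffected; they only take effect when the converted params set them. A hypothetical example of turning them on — the field names match this diff, but the values are made up, and in practice they come from the checkpoint's params.json via model_args_conversion rather than being set by hand:

```python
from executorch.examples.models.llama.model_args import ModelArgs

# Hypothetical values for illustration; the remaining architecture fields
# (dim, n_heads, hidden_dim, ...) are left at their defaults here.
args = ModelArgs(
    use_qk_norm=True,
    qk_norm_before_rope=True,
    qk_norm_affine=False,
    normalize_tok_embeddings=True,
    scale_query_by=2.0,              # custom scalar applied to Q after QK norm
    use_attn_o_gate=True,            # sigmoid gate from a learned projection of the layer input
    use_attn_o_norm=True,            # scaleless per-head RMSNorm on the attention output
    use_residual_gate=True,          # NormPreservingResidualConnection on both residual adds
    use_ffn_learnable_scales=True,   # RMSNormWithInputScale after the FFN
    output_soft_cap_temp=30.0,       # tanh(logits / temp) * temp on the output logits
)
```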

examples/models/llama/norm.py

Lines changed: 21 additions & 19 deletions
```diff
@@ -32,33 +32,35 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
         self.weight = nn.Parameter(torch.ones(dim))
 
     def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
 
-        Args:
-            x (torch.Tensor): The input tensor.
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        if self.add_unit_offset:
+            return output * (1.0 + self.weight.float()).type_as(x)
+        return output * self.weight.type_as(x)
 
-        Returns:
-            torch.Tensor: The normalized tensor.
 
-        """
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+class ScalelessRMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
 
     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
 
-        Args:
-            x (torch.Tensor): The input tensor.
 
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
+class RMSNormWithInputScale(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.dim = dim
+        self.weight = torch.nn.Parameter(torch.ones(dim))
 
-        """
-        output = self._norm(x.float()).type_as(x)
-        if self.add_unit_offset:
-            return output * (1.0 + self.weight.float()).type_as(x)
-        return output * self.weight.type_as(x)
+    def forward(self, x):
+        scaled = self.weight * x
+        return scaled * torch.rsqrt((scaled * scaled).mean(-1, keepdim=True) + self.eps)
 
 
 class RMSNormGated(nn.Module):
```
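The distinction the summary draws for `use_ffn_learnable_scales` — `rms_norm(gamma * x)` versus the standard `gamma * rms_norm(x)` — matters because scaling before the norm changes the RMS statistic used in the denominator: a uniform `gamma` cancels out of the input-scaled form entirely, while a per-dimension `gamma` reweights dimensions before normalization. A quick standalone comparison, using toy shapes and a functional rewrite of the two norms above:

```python
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + eps)

torch.manual_seed(0)
x = torch.randn(2, 3, 8)
gamma = torch.rand(8) + 0.5           # non-uniform per-dim scales

standard = gamma * rms_norm(x)        # RMSNorm: normalize first, then scale
input_scaled = rms_norm(gamma * x)    # RMSNormWithInputScale: scale first, then normalize

# Nonzero difference: the two orderings are not equivalent, because the
# normalization denominator is rms(x) in one case and rms(gamma * x) in the other.
print((standard - input_scaled).abs().max())
```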

0 commit comments

Comments
 (0)