
Commit 8a97ac7

Add rlformers forward-pass features to ExecuTorch backbone for on-device export parity (#19096)
Differential Revision: D102030169
Pull Request resolved: #19096
1 parent e4ede92 commit 8a97ac7

5 files changed

Lines changed: 201 additions & 53 deletions


examples/models/llama/attention.py

Lines changed: 49 additions & 12 deletions
@@ -7,7 +7,11 @@
 import torch.nn.functional as F
 from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
+from executorch.examples.models.llama.norm import (
+    RMSNorm,
+    RMSNormGated,
+    ScalelessRMSNorm,
+)
 from executorch.examples.models.llama.rope import Rope

@@ -375,6 +379,9 @@ def __init__(
         self.qk_norm_before_rope = args.qk_norm_before_rope
         self.use_q_gate = args.use_q_gate
         self.enable_dynamic_shape = args.enable_dynamic_shape
+        self.scale_query_by = args.scale_query_by
+        self.use_attn_o_gate = args.use_attn_o_gate
+        self.use_attn_o_norm = args.use_attn_o_norm
         q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)

         # YOCO: Determine if this is a KV shared layer (receives shared KV from donor).
@@ -417,17 +424,26 @@ def __init__(
     def _init_norms(self, args: ModelArgs) -> None:
         """Initialize QK normalization layers."""
         if self.use_qk_norm:
-            self.q_norm_fn = RMSNorm(
-                self.head_dim,
-                eps=args.norm_eps,
-                add_unit_offset=args.rms_norm_add_unit_offset,
-            )
-            if self.has_kv_weights:
-                self.k_norm_fn = RMSNorm(
+            if args.qk_norm_affine:
+                self.q_norm_fn = RMSNorm(
                     self.head_dim,
                     eps=args.norm_eps,
                     add_unit_offset=args.rms_norm_add_unit_offset,
                 )
+                if self.has_kv_weights:
+                    self.k_norm_fn = RMSNorm(
+                        self.head_dim,
+                        eps=args.norm_eps,
+                        add_unit_offset=args.rms_norm_add_unit_offset,
+                    )
+            else:
+                self.q_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+                if self.has_kv_weights:
+                    self.k_norm_fn = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_norm:
+            self.o_norm = ScalelessRMSNorm(self.head_dim, eps=args.norm_eps)
+        if self.use_attn_o_gate:
+            self.og = nn.Linear(args.dim, self.n_heads * self.head_dim, bias=False)

     def _init_projections(self, args: ModelArgs, q_out_dim: int) -> None:
         """Initialize Q/K/V/O projection layers."""
@@ -478,13 +494,17 @@ def _prepare_qkv_shared(

         if self.use_qk_norm and self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by

         # Apply RoPE to Q only (K already has RoPE from donor layer)
         q, _ = self.rope.forward(q, q, freqs_cos, freqs_sin)
         q = q.transpose(1, 2)

         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by

         return q, k, v

@@ -508,6 +528,8 @@ def _prepare_qkv(

         if self.use_qk_norm and self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by
             k = self.k_norm_fn(k)

         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)
@@ -518,6 +540,8 @@ def _prepare_qkv(

         if self.use_qk_norm and not self.qk_norm_before_rope:
             q = self.q_norm_fn(q)
+            if self.scale_query_by != 1.0:
+                q = q * self.scale_query_by
             k = self.k_norm_fn(k)

         return q, k, v
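
Note: scale_query_by multiplies the (optionally normalized) queries by a constant in both the shared-KV path and the regular path above. Since the pre-softmax logits are dot products q·k, this is equivalent to scaling the attention logits by the same factor; the ModelArgs default of 1.0 leaves the computation unchanged.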
@@ -582,8 +606,7 @@ def forward(
             )

             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
-            if gate is not None:
-                output = output * torch.sigmoid(gate)
+            output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

             if shared_kv is None and self.num_kv_shared_layers > 0:
                 update = {"kv_to_share": (k, v)}
@@ -602,13 +625,27 @@ def forward(
         output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
-        if gate is not None:
-            output = output * torch.sigmoid(gate)
+        output = self._apply_output_transforms(output, x, gate, bsz, seqlen)

         output = self.wo(output)

         return output, None

+    def _apply_output_transforms(
+        self, output: torch.Tensor, x: torch.Tensor, gate, bsz: int, seqlen: int
+    ) -> torch.Tensor:
+        if self.use_attn_o_norm or self.use_attn_o_gate:
+            output_4d = output.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            if self.use_attn_o_norm:
+                output_4d = self.o_norm(output_4d)
+            if self.use_attn_o_gate:
+                og = self.og(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+                output_4d = torch.sigmoid(og) * output_4d
+            output = output_4d.reshape(bsz, seqlen, -1)
+        if gate is not None:
+            output = output * torch.sigmoid(gate)
+        return output
+

 def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
     inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
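
Note: a minimal standalone sketch of the new output path in _apply_output_transforms, with made-up sizes and torch.nn.functional.rms_norm standing in for ScalelessRMSNorm; it mirrors the method above but is not the module itself.

import torch
import torch.nn as nn
import torch.nn.functional as F

bsz, seqlen, n_heads, head_dim, dim = 2, 8, 4, 16, 64     # illustrative sizes only
x = torch.randn(bsz, seqlen, dim)                         # block input, source of the output gate
attn_out = torch.randn(bsz, seqlen, n_heads * head_dim)   # flattened SDPA output

og_proj = nn.Linear(dim, n_heads * head_dim, bias=False)  # plays the role of self.og

out = attn_out.view(bsz, seqlen, n_heads, head_dim)       # back to per-head layout
out = F.rms_norm(out, (head_dim,), None, 1e-6)            # use_attn_o_norm: scaleless RMS norm per head
og = og_proj(x).view(bsz, seqlen, n_heads, head_dim)      # use_attn_o_gate: gate projected from x
out = torch.sigmoid(og) * out
out = out.reshape(bsz, seqlen, -1)                        # flatten again before the wo projection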

examples/models/llama/llama_transformer.py

Lines changed: 77 additions & 9 deletions
@@ -7,6 +7,7 @@

 # Please refer to README.md in the same folder for more information.

+import math
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
@@ -20,7 +21,11 @@
 )
 from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm
+from executorch.examples.models.llama.norm import (
+    RMSNorm,
+    RMSNormWithInputScale,
+    ScalelessRMSNorm,
+)
 from executorch.examples.models.llama.rope import Rope
 from torch import nn

@@ -51,6 +56,26 @@ def _is_kv_shared_layer(
     return layer_idx >= first_shared and first_shared > 0


+class NormPreservingResidualConnection(nn.Module):
+    def __init__(
+        self, dim: int, init_scale: float, temperature: float = 0.3, eps: float = 1e-3
+    ):
+        super().__init__()
+        self.eps = eps
+        self.temperature = temperature
+        p = max(0.0 + eps, min(1.0 - eps, init_scale))
+        init_param = math.log(p / (1.0 - p)) * temperature
+        self.gate = nn.Parameter(torch.full((dim,), init_param))
+
+    def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
+        dtype = stream.dtype
+        w = self.gate.view(*([1] * (stream.ndim - 1)), -1).float()
+        beta = torch.sigmoid(w / self.temperature)
+        alpha_sq = torch.sigmoid(-w / self.temperature) * (1.0 + beta)
+        alpha = torch.sqrt(torch.clamp(alpha_sq, min=self.eps))
+        return (alpha * stream.float() + beta * branch.float()).to(dtype)
+
+
 class ConditionalFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
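
Note on the gate math: with temperature T, beta = sigmoid(w / T) and alpha^2 = sigmoid(-w / T) * (1 + beta) = (1 - beta) * (1 + beta) = 1 - beta^2, so alpha^2 + beta^2 = 1 per channel. The residual add therefore interpolates on the unit circle, which is what keeps the scale of the combined stream roughly constant when stream and branch are uncorrelated and similarly scaled. The constructor maps init_scale = p to w = T * logit(p), so the branch weight beta starts at exactly p. A quick sanity check, assuming only the class above (not part of the commit):

import torch
from executorch.examples.models.llama.llama_transformer import (
    NormPreservingResidualConnection,
)

gate = NormPreservingResidualConnection(dim=4, init_scale=0.25)
w = gate.gate.float()
beta = torch.sigmoid(w / gate.temperature)
alpha_sq = torch.sigmoid(-w / gate.temperature) * (1.0 + beta)
print(beta[0].item())                    # ~0.25, the requested init_scale
print((alpha_sq + beta ** 2)[0].item())  # ~1.0, norm-preserving by construction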
@@ -99,7 +124,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

 class TransformerBlock(nn.Module):
     def __init__(
-        self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
+        self,
+        args: ModelArgs,
+        attention: Attention,
+        mlp_type: str = "default",
+        layer_id: int = 0,
     ):
         """
         Transformer block with support for pre-norm and post-norm.
@@ -110,6 +139,7 @@ def __init__(
                 the attention type is registered in the ATTENTION_REGISTRY.
             mlp_type (str): MLP type for this layer. "default" for standard
                 FFN, "skip" for no FFN block.
+            layer_id (int): layer index, used for residual gate initialization.
         """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
@@ -118,6 +148,7 @@ def __init__(
         self.head_dim = args.head_dim
         self.attention = attention
         self.mlp_type = mlp_type.lower()
+        self.use_residual_gate = args.use_residual_gate

         assert (
             args.hidden_dim is not None
@@ -150,6 +181,20 @@ def __init__(
             add_unit_offset=args.rms_norm_add_unit_offset,
         )

+        if args.use_residual_gate:
+            attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
+            ffn_init = 1.0 / (2 * layer_id + 2)
+            self.add_attn = NormPreservingResidualConnection(
+                dim=args.dim, init_scale=attn_init
+            )
+            self.add_ffn = NormPreservingResidualConnection(
+                dim=args.dim, init_scale=ffn_init
+            )
+            self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps)
+
+        if args.use_ffn_learnable_scales and self.mlp_type != "skip":
+            self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps)
+
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         """
@@ -169,21 +214,38 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
         mlp_type = args.mlp_type[layer_id]
         cls = ATTENTION_REGISTRY[args.attention_type]
         attention = cls(args, layer_id, rope, **args.attention_kwargs)
-        return TransformerBlock(args, attention, mlp_type=mlp_type)
+        return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id)

     def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
         h, attn_options_update = self.attention(
             self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
         )
         if not isinstance(self.attention, AttentionSkip):
-            h = x + h
+            if self.use_residual_gate:
+                if hasattr(self, "post_attn_norm"):
+                    h = self.post_attn_norm(h)
+                h = self.add_attn(stream=x, branch=h)
+            else:
+                h = x + h

         if self.mlp_type == "skip":
             out = h
         elif hasattr(self, "block_sparse_moe"):
-            out = h + self.block_sparse_moe(self.ffn_norm(h))
+            ffn_out = self.block_sparse_moe(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
         else:
-            out = h + self.feed_forward(self.ffn_norm(h))
+            ffn_out = self.feed_forward(self.ffn_norm(h))
+            if hasattr(self, "post_ffn_norm"):
+                ffn_out = self.post_ffn_norm(ffn_out)
+            if self.use_residual_gate:
+                out = self.add_ffn(stream=h, branch=ffn_out)
+            else:
+                out = h + ffn_out
         return out, attn_options_update

@@ -371,7 +433,9 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             and model_args.layer_types[layer_id] == "skip_attention"
         ):
             attention = AttentionSkip()
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(
+                model_args, attention, layer_id=layer_id
+            )
             layers.append(transformer_block)
         elif (
             model_args.layer_types
@@ -386,13 +450,17 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
             attention = linear_cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
             )
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(
+                model_args, attention, layer_id=layer_id
+            )
             layers.append(transformer_block)
         else:
             attention = cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
             )  # pyre-ignore[45]
-            transformer_block = TransformerBlock(model_args, attention)
+            transformer_block = TransformerBlock(
+                model_args, attention, layer_id=layer_id
+            )
             layers.append(transformer_block)

     return Transformer(model_args, layers, rope)

examples/models/llama/model_args.py

Lines changed: 12 additions & 0 deletions
@@ -102,6 +102,9 @@ class ModelArgs:
     apply_output: bool = True  # Use output layer (unembedding) inside the transformer
     use_qk_norm: bool = False  # apply normalization to q and k in the attention
     qk_norm_before_rope: bool = False  # when to apply qk norm
+    qk_norm_affine: bool = (
+        True  # whether QK norm has learnable weight (False = scaleless)
+    )
     residual_multiplier: Optional[float] = (
         None  # Scaling factor applied to the residual hidden states
     )
@@ -162,6 +165,15 @@ class ModelArgs:
     final_logit_softcapping: Optional[float] = None
     attn_logit_softcapping: Optional[float] = None

+    # rlformers forward-pass features for on-device model parity
+    normalize_tok_embeddings: bool = False
+    scale_query_by: float = 1.0
+    use_attn_o_gate: bool = False
+    use_attn_o_norm: bool = False
+    use_residual_gate: bool = False
+    use_ffn_learnable_scales: bool = False
+    output_soft_cap_temp: Optional[float] = None
+
     def __post_init__(self):  # noqa: C901
         if self.n_kv_heads is None:
             self.n_kv_heads = self.n_heads
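
Note: all of the new switches are plain ModelArgs fields, so a model configuration can opt into them without touching the module code. A sketch of such a configuration (field names come from this diff; the values are illustrative, and normalize_tok_embeddings / output_soft_cap_temp are consumed outside the hunks shown in this excerpt):

from executorch.examples.models.llama.model_args import ModelArgs

# Illustrative configuration exercising the rlformers-style forward-pass features.
args = ModelArgs(
    use_qk_norm=True,
    qk_norm_affine=False,           # Q/K norm without learnable weight (ScalelessRMSNorm)
    scale_query_by=2.0,             # constant multiplier on the normalized queries
    use_attn_o_gate=True,           # sigmoid output gate projected from the block input
    use_attn_o_norm=True,           # per-head ScalelessRMSNorm on the attention output
    use_residual_gate=True,         # NormPreservingResidualConnection around attn and FFN
    use_ffn_learnable_scales=True,  # RMSNormWithInputScale applied to the FFN branch
)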

examples/models/llama/norm.py

Lines changed: 24 additions & 19 deletions
@@ -32,33 +32,38 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False):
         self.weight = nn.Parameter(torch.ones(dim))

     def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
+        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)

-        Args:
-            x (torch.Tensor): The input tensor.
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        if self.add_unit_offset:
+            return output * (1.0 + self.weight.float()).type_as(x)
+        return output * self.weight.type_as(x)

-        Returns:
-            torch.Tensor: The normalized tensor.

-        """
-        return x * torch.rsqrt((x * x).mean(-1, keepdim=True) + self.eps)
+class ScalelessRMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.dim = dim
+        self.eps = eps

     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
+        x_float = x.float()
+        return (
+            x_float * torch.rsqrt((x_float * x_float).mean(-1, keepdim=True) + self.eps)
+        ).type_as(x)

-        Args:
-            x (torch.Tensor): The input tensor.

-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
+class RMSNormWithInputScale(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.dim = dim
+        self.weight = torch.nn.Parameter(torch.ones(dim))

-        """
-        output = self._norm(x.float()).type_as(x)
-        if self.add_unit_offset:
-            return output * (1.0 + self.weight.float()).type_as(x)
-        return output * self.weight.type_as(x)
+    def forward(self, x):
+        scaled = self.weight * x
+        return F.rms_norm(scaled, (self.dim,), None, self.eps)


 class RMSNormGated(nn.Module):
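
Note: the three variants differ only in where the learnable scale sits. RMSNorm normalizes and then multiplies by a per-channel weight (optionally with a unit offset); ScalelessRMSNorm is the same normalization with no parameters; RMSNormWithInputScale multiplies the input by a per-channel weight before normalizing, so the weight reweights channels relative to each other while the output stays at unit RMS. A small comparison sketch; at initialization all weights are ones, so the outputs agree up to the different eps defaults:

import torch
from executorch.examples.models.llama.norm import (
    RMSNorm,
    RMSNormWithInputScale,
    ScalelessRMSNorm,
)

x = torch.randn(2, 5, 64)

plain = ScalelessRMSNorm(64)(x)              # x / rms(x), no parameters
weighted = RMSNorm(64)(x)                    # x / rms(x) * weight, weight == ones at init
input_scaled = RMSNormWithInputScale(64)(x)  # (weight * x) / rms(weight * x)

print(torch.allclose(plain, weighted))                 # True at init
print(torch.allclose(plain, input_scaled, atol=1e-4))  # True at init (eps defaults differ)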
