
# Please refer to README.md in the same folder for more information.

+ import math
from typing import Any, Dict, Optional, Tuple, Union

import torch
@@ -20,7 +21,11 @@
)
from executorch.examples.models.llama.feed_forward import FeedForward, LoRAFeedForward
from executorch.examples.models.llama.model_args import ModelArgs
- from executorch.examples.models.llama.norm import RMSNorm
+ from executorch.examples.models.llama.norm import (
+     RMSNorm,
+     RMSNormWithInputScale,
+     ScalelessRMSNorm,
+ )
from executorch.examples.models.llama.rope import Rope
from torch import nn

@@ -51,6 +56,26 @@ def _is_kv_shared_layer(
    return layer_idx >= first_shared and first_shared > 0


+ class NormPreservingResidualConnection(nn.Module):
+     def __init__(
+         self, dim: int, init_scale: float, temperature: float = 0.3, eps: float = 1e-3
+     ):
+         super().__init__()
+         self.eps = eps
+         self.temperature = temperature
+         p = max(0.0 + eps, min(1.0 - eps, init_scale))
+         init_param = math.log(p / (1.0 - p)) * temperature
+         self.gate = nn.Parameter(torch.full((dim,), init_param))
+
+     def forward(self, stream: torch.Tensor, branch: torch.Tensor) -> torch.Tensor:
+         dtype = stream.dtype
+         w = self.gate.view(*([1] * (stream.ndim - 1)), -1).float()
+         beta = torch.sigmoid(w / self.temperature)
+         alpha_sq = torch.sigmoid(-w / self.temperature) * (1.0 + beta)
+         alpha = torch.sqrt(torch.clamp(alpha_sq, min=self.eps))
+         return (alpha * stream.float() + beta * branch.float()).to(dtype)
+
+
class ConditionalFeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
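
Note on the hunk above: because sigmoid(-x) = 1 - sigmoid(x), the gate reduces to beta = sigmoid(w / T) and alpha = sqrt(1 - beta**2), so the two mixing weights always satisfy alpha**2 + beta**2 = 1 (hence "norm preserving"), and init_scale is exactly the branch weight beta at initialization. A minimal sketch of these identities, for illustration only; the constants below are example values and not part of the change:

import math
import torch

temperature, init_scale = 0.3, 0.25  # temperature matches the class default; init_scale is arbitrary
w = torch.tensor(math.log(init_scale / (1.0 - init_scale)) * temperature)

beta = torch.sigmoid(w / temperature)                               # branch weight
alpha = torch.sqrt(torch.sigmoid(-w / temperature) * (1.0 + beta))  # stream weight

print(round(beta.item(), 4))                  # 0.25 -> beta at init equals init_scale
print(round((alpha**2 + beta**2).item(), 4))  # 1.0  -> norm-preserving mix
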
@@ -99,7 +124,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class TransformerBlock(nn.Module):
    def __init__(
-         self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
+         self,
+         args: ModelArgs,
+         attention: Attention,
+         mlp_type: str = "default",
+         layer_id: int = 0,
    ):
        """
        Transformer block with support for pre-norm and post-norm.
@@ -110,6 +139,7 @@ def __init__(
                the attention type is registered in the ATTENTION_REGISTRY.
            mlp_type (str): MLP type for this layer. "default" for standard
                FFN, "skip" for no FFN block.
+             layer_id (int): layer index, used for residual gate initialization.
        """
        super().__init__()
        self.use_kv_cache = args.use_kv_cache
@@ -118,6 +148,7 @@ def __init__(
        self.head_dim = args.head_dim
        self.attention = attention
        self.mlp_type = mlp_type.lower()
+         self.use_residual_gate = args.use_residual_gate

        assert (
            args.hidden_dim is not None
@@ -150,6 +181,20 @@ def __init__(
            add_unit_offset=args.rms_norm_add_unit_offset,
        )

+         if args.use_residual_gate:
+             attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
+             ffn_init = 1.0 / (2 * layer_id + 2)
+             self.add_attn = NormPreservingResidualConnection(
+                 dim=args.dim, init_scale=attn_init
+             )
+             self.add_ffn = NormPreservingResidualConnection(
+                 dim=args.dim, init_scale=ffn_init
+             )
+             self.post_attn_norm = ScalelessRMSNorm(args.dim, eps=args.norm_eps)
+
+         if args.use_ffn_learnable_scales and self.mlp_type != "skip":
+             self.post_ffn_norm = RMSNormWithInputScale(args.dim, eps=args.norm_eps)
+
    @classmethod
    def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
        """
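
Note on the hunk above: the layer-indexed initialization gives each residual branch a starting weight of roughly one over its sub-block index: attention of layer i starts at 1 / (2 * i + 1), capped at 0.5 for layer 0, and the FFN at 1 / (2 * i + 2). A small sketch of the schedule for the first few layers, for illustration only; the helper name below is hypothetical and not part of the change:

def residual_gate_inits(n_layers: int):
    # Mirrors the attn_init / ffn_init expressions in the hunk above.
    for layer_id in range(n_layers):
        attn_init = 1.0 / (2 * layer_id + 1) if layer_id > 0 else 0.5
        ffn_init = 1.0 / (2 * layer_id + 2)
        yield layer_id, attn_init, ffn_init

for layer_id, attn_init, ffn_init in residual_gate_inits(3):
    print(layer_id, round(attn_init, 3), round(ffn_init, 3))
# 0 0.5 0.5
# 1 0.333 0.25
# 2 0.2 0.167
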
@@ -169,21 +214,38 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
        mlp_type = args.mlp_type[layer_id]
        cls = ATTENTION_REGISTRY[args.attention_type]
        attention = cls(args, layer_id, rope, **args.attention_kwargs)
-         return TransformerBlock(args, attention, mlp_type=mlp_type)
+         return TransformerBlock(args, attention, mlp_type=mlp_type, layer_id=layer_id)

    def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions):  # x: 1xN
        h, attn_options_update = self.attention(
            self.attention_norm(x), freqs_cos, freqs_sin, **attn_options
        )
        if not isinstance(self.attention, AttentionSkip):
-             h = x + h
+             if self.use_residual_gate:
+                 if hasattr(self, "post_attn_norm"):
+                     h = self.post_attn_norm(h)
+                 h = self.add_attn(stream=x, branch=h)
+             else:
+                 h = x + h

        if self.mlp_type == "skip":
            out = h
        elif hasattr(self, "block_sparse_moe"):
-             out = h + self.block_sparse_moe(self.ffn_norm(h))
+             ffn_out = self.block_sparse_moe(self.ffn_norm(h))
+             if hasattr(self, "post_ffn_norm"):
+                 ffn_out = self.post_ffn_norm(ffn_out)
+             if self.use_residual_gate:
+                 out = self.add_ffn(stream=h, branch=ffn_out)
+             else:
+                 out = h + ffn_out
        else:
-             out = h + self.feed_forward(self.ffn_norm(h))
+             ffn_out = self.feed_forward(self.ffn_norm(h))
+             if hasattr(self, "post_ffn_norm"):
+                 ffn_out = self.post_ffn_norm(ffn_out)
+             if self.use_residual_gate:
+                 out = self.add_ffn(stream=h, branch=ffn_out)
+             else:
+                 out = h + ffn_out
        return out, attn_options_update


@@ -371,7 +433,9 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
            and model_args.layer_types[layer_id] == "skip_attention"
        ):
            attention = AttentionSkip()
-             transformer_block = TransformerBlock(model_args, attention)
+             transformer_block = TransformerBlock(
+                 model_args, attention, layer_id=layer_id
+             )
            layers.append(transformer_block)
        elif (
            model_args.layer_types
@@ -386,13 +450,17 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
            attention = linear_cls(
                model_args, layer_id, rope, **model_args.attention_kwargs
            )
-             transformer_block = TransformerBlock(model_args, attention)
+             transformer_block = TransformerBlock(
+                 model_args, attention, layer_id=layer_id
+             )
            layers.append(transformer_block)
        else:
            attention = cls(
                model_args, layer_id, rope, **model_args.attention_kwargs
            )  # pyre-ignore[45]
-             transformer_block = TransformerBlock(model_args, attention)
+             transformer_block = TransformerBlock(
+                 model_args, attention, layer_id=layer_id
+             )
            layers.append(transformer_block)

    return Transformer(model_args, layers, rope)