@@ -98,26 +98,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, args: ModelArgs, attention: Attention):
+    def __init__(
+        self, args: ModelArgs, attention: Attention, mlp_type: str = "default"
+    ):
         """
         Transformer block with support for pre-norm and post-norm.
         Args:
             args (ModelArgs): model configuration parameters.
             attention (Attention): attention object to use in the transformer
                 block. See `attention.py` for types of attention. Make sure
                 the attention type is registered in the ATTENTION_REGISTRY.
+            mlp_type (str): MLP type for this layer. "default" for a standard
+                FFN, "skip" for no FFN block.
         """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
         self.dim = args.dim
         self.head_dim = args.head_dim
         self.attention = attention
+        self.mlp_type = mlp_type.lower()
 
         assert (
             args.hidden_dim is not None
         ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock."
-        if args.moe:
+        if self.mlp_type == "skip":
+            pass  # No FFN block for this layer
+        elif args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         elif args.target_modules is not None and (
             "down_proj" in args.target_modules
@@ -136,11 +143,12 @@ def __init__(self, args: ModelArgs, attention: Attention):
             eps=args.norm_eps,
             add_unit_offset=args.rms_norm_add_unit_offset,
         )
-        self.ffn_norm = RMSNorm(
-            args.dim,
-            eps=args.norm_eps,
-            add_unit_offset=args.rms_norm_add_unit_offset,
-        )
+        if self.mlp_type != "skip":
+            self.ffn_norm = RMSNorm(
+                args.dim,
+                eps=args.norm_eps,
+                add_unit_offset=args.rms_norm_add_unit_offset,
+            )
 
     @classmethod
     def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
@@ -156,9 +164,12 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock":
156164 f"Unknown attention type: { args .attention_type } . "
157165 f"Available: { list (ATTENTION_REGISTRY .keys ())} "
158166 )
167+ mlp_type = "default"
168+ if args .mlp_type is not None and layer_id < len (args .mlp_type ):
169+ mlp_type = args .mlp_type [layer_id ]
159170 cls = ATTENTION_REGISTRY [args .attention_type ]
160171 attention = cls (args , layer_id , rope , ** args .attention_kwargs )
161- return TransformerBlock (args , attention )
172+ return TransformerBlock (args , attention , mlp_type = mlp_type )
162173
163174 def forward (self , x , freqs_cos , freqs_sin , attn_options : ForwardOptions ): # x: 1xN
164175 h , attn_options_update = self .attention (
@@ -167,7 +178,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
         if not isinstance(self.attention, AttentionSkip):
             h = x + h
 
-        if hasattr(self, "block_sparse_moe"):
+        if self.mlp_type == "skip":
+            out = h
+        elif hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))
         else:
             out = h + self.feed_forward(self.ffn_norm(h))
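
To show how the new per-layer mlp_type plumbing behaves, here is a minimal standalone sketch of the lookup that from_type performs above. The helper name resolve_mlp_type and the example config list are illustrative assumptions, not part of this change; the fallback mirrors the diff: a layer with no entry (or no mlp_type list at all) gets the "default" FFN, while a "skip" entry makes the block return the attention residual unchanged.

# Standalone sketch of the per-layer mlp_type lookup done in
# TransformerBlock.from_type above. The helper and config values here are
# illustrative assumptions, not code from this PR.
from typing import List, Optional


def resolve_mlp_type(mlp_types: Optional[List[str]], layer_id: int) -> str:
    # Same rule as from_type: fall back to "default" when the list is
    # missing or shorter than the layer index.
    if mlp_types is not None and layer_id < len(mlp_types):
        return mlp_types[layer_id]
    return "default"


if __name__ == "__main__":
    # Hypothetical config for a 4-layer model that drops the FFN in layer 2.
    cfg = ["default", "default", "skip"]  # shorter than n_layers on purpose
    for layer_id in range(4):
        print(layer_id, resolve_mlp_type(cfg, layer_id))
    # Prints: 0 default, 1 default, 2 skip, 3 default (past-the-end fallback)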