@@ -4,8 +4,7 @@
 Transformer layers for Wan2.1 DiT.
 
 Norms, attention, blocks, and output head. Uses bidirectional (non-causal)
-attention with setattr-based block registration for weight remapping
-compatibility.
+attention with fused norm+modulate via mx.fast.layer_norm.
 """
 
 import math
@@ -18,26 +17,11 @@
 from .rope import rope_apply
 
 
-@partial(mx.compile, shapeless=True)
-def _modulate(x, scale, shift):
-    return x * (1 + scale) + shift
-
-
 @partial(mx.compile, shapeless=True)
 def _residual_gate(x, y, gate):
     return x + y * gate
 
 
-_gelu = mx.compile(nn.gelu_approx)
-
-
-@partial(mx.compile, shapeless=True)
-def _layer_norm(x, eps):
-    mean = x.mean(axis=-1, keepdims=True)
-    var = x.var(axis=-1, keepdims=True)
-    return (x - mean) / mx.sqrt(var + eps)
-
-
 class WanRMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5):
         super().__init__()
@@ -48,22 +32,6 @@ def __call__(self, x: mx.array) -> mx.array:
         return mx.fast.rms_norm(x, self.weight, self.eps)
 
 
-class WanLayerNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
-        super().__init__()
-        self.eps = eps
-        self.elementwise_affine = elementwise_affine
-        if elementwise_affine:
-            self.weight = mx.ones((dim,))
-            self.bias = mx.zeros((dim,))
-
-    def __call__(self, x: mx.array) -> mx.array:
-        if self.elementwise_affine:
-            return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
-        else:
-            return _layer_norm(x, self.eps)
-
-
 class WanSelfAttention(nn.Module):
     def __init__(
         self,
@@ -213,8 +181,9 @@ class WanAttentionBlock(nn.Module):
     """
     Transformer block with self-attn, cross-attn, and FFN.
 
-    Uses ffn_linear1/ffn_linear2 naming (not nn.Sequential) for weight
-    remapping compatibility and selective quantization.
+    Uses fused norm+modulate via mx.fast.layer_norm where the modulation
+    scale/shift are passed as weight/bias. Requires sanitize to bake 1+
+    into modulation scale positions.
     """
 
     def __init__(
@@ -228,19 +197,21 @@ def __init__(
     ):
         super().__init__()
         self.dim = dim
+        self.eps = eps
 
-        self.norm1 = WanLayerNorm(dim, eps)
-        self.norm2 = WanLayerNorm(dim, eps)
         if cross_attn_norm:
-            self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True)
+            self.norm3 = nn.LayerNorm(dim, eps=eps)
         else:
             self.norm3 = None
 
         self.self_attn = WanSelfAttention(dim, num_heads, eps)
         self.cross_attn = _cross_attn_classes[cross_attn_type](dim, num_heads, eps)
 
-        self.ffn_linear1 = nn.Linear(dim, ffn_dim)
-        self.ffn_linear2 = nn.Linear(ffn_dim, dim)
+        self.ffn = nn.Sequential(
+            nn.Linear(dim, ffn_dim),
+            nn.GELU(approx="tanh"),
+            nn.Linear(ffn_dim, dim),
+        )
 
         self.modulation = mx.zeros((1, 6, dim))
 
@@ -255,10 +226,9 @@ def __call__(
     ) -> mx.array:
         e = self.modulation + e
 
-        # Self-attention with modulation
-        x_norm = self.norm1(x)
+        # Self-attention: fused norm + modulate
         y = self.self_attn(
-            _modulate(x_norm, e[:, 1], e[:, 0]),
+            mx.fast.layer_norm(x, e[0, 1], e[0, 0], self.eps),
             grid_sizes,
             freqs,
         )
@@ -271,18 +241,15 @@ def __call__(
             x_normed = x
         x = x + self.cross_attn(x_normed, context, context_lens)
 
-        # FFN with modulation
-        x_norm = self.norm2(x)
-        y = self.ffn_linear2(
-            _gelu(self.ffn_linear1(_modulate(x_norm, e[:, 4], e[:, 3])))
-        )
+        # FFN: fused norm + modulate
+        y = self.ffn(mx.fast.layer_norm(x, e[0, 4], e[0, 3], self.eps))
         x = _residual_gate(x, y, e[:, 5])
 
         return x
 
 
 class Head(nn.Module):
-    """Output head with modulation. Uses raw weight arrays for remapping compat."""
+    """Output head with fused norm+modulate and nn.Linear."""
 
     def __init__(
         self,
@@ -293,23 +260,12 @@ def __init__(
     ):
         super().__init__()
         self.dim = dim
+        self.eps = eps
         out_features = math.prod(patch_size) * out_dim
-        self.norm = WanLayerNorm(dim, eps)
-        scale = 1.0 / dim**0.5
-        self.head_weight = mx.random.uniform(
-            low=-scale, high=scale, shape=(out_features, dim)
-        )
-        self.head_bias = mx.zeros((out_features,))
+        self.linear = nn.Linear(dim, out_features)
         self.modulation = mx.zeros((1, 2, dim))
 
     def __call__(self, x: mx.array, e: mx.array) -> mx.array:
         e = self.modulation + e[:, None, :]
-        x_norm = self.norm(x)
-        x = (
-            mx.matmul(
-                _modulate(x_norm, e[:, 1], e[:, 0]),
-                self.head_weight.T,
-            )
-            + self.head_bias
-        )
-        return x
+        x = mx.fast.layer_norm(x, e[0, 1], e[0, 0], self.eps)
+        return self.linear(x)
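
As a quick sanity check (not part of the diff): a minimal sketch of why the fused call matches the old `_modulate(layer_norm(x), scale, shift)` path once sanitize bakes `1 +` into the modulation scale rows. `mx.fast.layer_norm` applies its weight/bias as an elementwise scale and shift after normalization, so passing `1 + scale` as the weight and `shift` as the bias reproduces `x_norm * (1 + scale) + shift`. Shapes and values below are illustrative only.

```python
import mlx.core as mx

# Illustrative sizes; dim/eps are placeholders, not the model's values.
dim, eps = 16, 1e-6
x = mx.random.normal((1, 4, dim))
scale = mx.random.normal((dim,))  # a modulation "scale" row, pre-sanitize
shift = mx.random.normal((dim,))  # the matching "shift" row

# Old path: affine-free layer norm, then x_norm * (1 + scale) + shift.
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
x_norm = (x - mean) / mx.sqrt(var + eps)
old = x_norm * (1 + scale) + shift

# New path: one fused kernel, with the 1+ already baked into the scale
# (what sanitize does to the modulation scale positions at load time).
new = mx.fast.layer_norm(x, 1 + scale, shift, eps)

print(mx.allclose(old, new, atol=1e-5))  # expected: True (up to numerics)
```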