
Commit 89668b9

PR fixes, pt1
Author: Gleb Sterkin (committed)
1 parent: ddd567c

10 files changed: 182 additions & 352 deletions


video/wan2.1/README.md

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ Hub](https://huggingface.co/Wan-AI).

 | Model | Task | HF Repo | RAM (unquantized), 81 frames | Single DiT step on M4 Max chip, 81 frames |
 |-------|------|---------|-----------------|---|
-| 1.3B | T2V | [Wan-AI/Wan2.1-T2V-1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) | ~10GB | ~100 s/it |
+| 1.3B | T2V | [Wan-AI/Wan2.1-T2V-1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B) | ~10GB | ~90 s/it |
 | 14B | T2V | [Wan-AI/Wan2.1-T2V-14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B) | ~36GB | ~230 s/it |
 | 14B | I2V | [Wan-AI/Wan2.1-I2V-14B-480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P) | ~39GB | ~250 s/it |

video/wan2.1/img2video.py

Lines changed: 1 addition & 4 deletions
@@ -71,9 +71,6 @@ def quantization_predicate(name, m):
 )
 parser.add_argument("--output", default="out.mp4")
 parser.add_argument("--preload-models", action="store_true")
-parser.add_argument(
-    "--compile-vae", action="store_true", help="Compile VAE decoder"
-)
 parser.add_argument(
     "--no-cache",
     action="store_true",
@@ -151,7 +148,7 @@ def quantization_predicate(name, m):
 mx.reset_peak_memory()

 # 3. VAE decode
-video = pipeline.decode(x_t, compile_vae=args.compile_vae)
+video = pipeline.decode(x_t)
 mx.eval(video)
 peak_mem_decoding = mx.get_peak_memory() / 1024**3

video/wan2.1/txt2video.py

Lines changed: 1 addition & 4 deletions
@@ -70,9 +70,6 @@ def quantization_predicate(name, m):
 )
 parser.add_argument("--output", default="out.mp4")
 parser.add_argument("--preload-models", action="store_true")
-parser.add_argument(
-    "--compile-vae", action="store_true", help="Compile VAE decoder"
-)
 parser.add_argument(
     "--no-cache",
     action="store_true",
@@ -147,7 +144,7 @@ def quantization_predicate(name, m):
 mx.reset_peak_memory()

 # 3. VAE decode
-video = pipeline.decode(x_t, compile_vae=args.compile_vae)
+video = pipeline.decode(x_t)
 mx.eval(video)
 peak_mem_decoding = mx.get_peak_memory() / 1024**3
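Both entry points keep the same bookkeeping around the decode step: reset the peak-memory counter, force evaluation of the lazily built graph, then read the peak. A small self-contained sketch of that pattern follows; the dummy computation merely stands in for pipeline.decode(x_t) and is not the repo's code.

import mlx.core as mx

mx.reset_peak_memory()

# Stand-in for the VAE decode; MLX only records this computation lazily.
x = mx.random.normal((8, 1024, 1024))
video = mx.tanh(x @ x.transpose(0, 2, 1))

# mx.eval materializes the result, so the peak-memory reading below
# reflects the memory actually used by the decode-like work.
mx.eval(video)
peak_mem_decoding = mx.get_peak_memory() / 1024**3
print(f"peak memory during decode: {peak_mem_decoding:.2f} GB")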

video/wan2.1/wan/layers.py

Lines changed: 20 additions & 64 deletions
@@ -4,8 +4,7 @@
 Transformer layers for Wan2.1 DiT.

 Norms, attention, blocks, and output head. Uses bidirectional (non-causal)
-attention with setattr-based block registration for weight remapping
-compatibility.
+attention with fused norm+modulate via mx.fast.layer_norm.
 """

 import math
@@ -18,26 +17,11 @@
 from .rope import rope_apply


-@partial(mx.compile, shapeless=True)
-def _modulate(x, scale, shift):
-    return x * (1 + scale) + shift
-
-
 @partial(mx.compile, shapeless=True)
 def _residual_gate(x, y, gate):
     return x + y * gate


-_gelu = mx.compile(nn.gelu_approx)
-
-
-@partial(mx.compile, shapeless=True)
-def _layer_norm(x, eps):
-    mean = x.mean(axis=-1, keepdims=True)
-    var = x.var(axis=-1, keepdims=True)
-    return (x - mean) / mx.sqrt(var + eps)
-
-
 class WanRMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5):
         super().__init__()
@@ -48,22 +32,6 @@ def __call__(self, x: mx.array) -> mx.array:
         return mx.fast.rms_norm(x, self.weight, self.eps)


-class WanLayerNorm(nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
-        super().__init__()
-        self.eps = eps
-        self.elementwise_affine = elementwise_affine
-        if elementwise_affine:
-            self.weight = mx.ones((dim,))
-            self.bias = mx.zeros((dim,))
-
-    def __call__(self, x: mx.array) -> mx.array:
-        if self.elementwise_affine:
-            return mx.fast.layer_norm(x, self.weight, self.bias, self.eps)
-        else:
-            return _layer_norm(x, self.eps)
-
-
 class WanSelfAttention(nn.Module):
     def __init__(
         self,
@@ -213,8 +181,9 @@ class WanAttentionBlock(nn.Module):
     """
     Transformer block with self-attn, cross-attn, and FFN.

-    Uses ffn_linear1/ffn_linear2 naming (not nn.Sequential) for weight
-    remapping compatibility and selective quantization.
+    Uses fused norm+modulate via mx.fast.layer_norm where the modulation
+    scale/shift are passed as weight/bias. Requires sanitize to bake 1+
+    into modulation scale positions.
     """

     def __init__(
@@ -228,19 +197,21 @@ def __init__(
     ):
         super().__init__()
         self.dim = dim
+        self.eps = eps

-        self.norm1 = WanLayerNorm(dim, eps)
-        self.norm2 = WanLayerNorm(dim, eps)
         if cross_attn_norm:
-            self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True)
+            self.norm3 = nn.LayerNorm(dim, eps=eps)
         else:
             self.norm3 = None

         self.self_attn = WanSelfAttention(dim, num_heads, eps)
         self.cross_attn = _cross_attn_classes[cross_attn_type](dim, num_heads, eps)

-        self.ffn_linear1 = nn.Linear(dim, ffn_dim)
-        self.ffn_linear2 = nn.Linear(ffn_dim, dim)
+        self.ffn = nn.Sequential(
+            nn.Linear(dim, ffn_dim),
+            nn.GELU(approx="tanh"),
+            nn.Linear(ffn_dim, dim),
+        )

         self.modulation = mx.zeros((1, 6, dim))

@@ -255,10 +226,9 @@ def __call__(
     ) -> mx.array:
         e = self.modulation + e

-        # Self-attention with modulation
-        x_norm = self.norm1(x)
+        # Self-attention: fused norm + modulate
         y = self.self_attn(
-            _modulate(x_norm, e[:, 1], e[:, 0]),
+            mx.fast.layer_norm(x, e[0, 1], e[0, 0], self.eps),
             grid_sizes,
             freqs,
         )
@@ -271,18 +241,15 @@ def __call__(
         x_normed = x
         x = x + self.cross_attn(x_normed, context, context_lens)

-        # FFN with modulation
-        x_norm = self.norm2(x)
-        y = self.ffn_linear2(
-            _gelu(self.ffn_linear1(_modulate(x_norm, e[:, 4], e[:, 3])))
-        )
+        # FFN: fused norm + modulate
+        y = self.ffn(mx.fast.layer_norm(x, e[0, 4], e[0, 3], self.eps))
         x = _residual_gate(x, y, e[:, 5])

         return x


 class Head(nn.Module):
-    """Output head with modulation. Uses raw weight arrays for remapping compat."""
+    """Output head with fused norm+modulate and nn.Linear."""

     def __init__(
         self,
@@ -293,23 +260,12 @@ def __init__(
     ):
         super().__init__()
         self.dim = dim
+        self.eps = eps
         out_features = math.prod(patch_size) * out_dim
-        self.norm = WanLayerNorm(dim, eps)
-        scale = 1.0 / dim**0.5
-        self.head_weight = mx.random.uniform(
-            low=-scale, high=scale, shape=(out_features, dim)
-        )
-        self.head_bias = mx.zeros((out_features,))
+        self.linear = nn.Linear(dim, out_features)
         self.modulation = mx.zeros((1, 2, dim))

     def __call__(self, x: mx.array, e: mx.array) -> mx.array:
         e = self.modulation + e[:, None, :]
-        x_norm = self.norm(x)
-        x = (
-            mx.matmul(
-                _modulate(x_norm, e[:, 1], e[:, 0]),
-                self.head_weight.T,
-            )
-            + self.head_bias
-        )
-        return x
+        x = mx.fast.layer_norm(x, e[0, 1], e[0, 0], self.eps)
+        return self.linear(x)
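The layers.py docstrings above describe the core of this change: the parameter-free layer norm followed by x * (1 + scale) + shift is collapsed into a single mx.fast.layer_norm call whose weight/bias slots carry the modulation scale/shift. Below is a minimal sketch of that equivalence, assuming the "1 +" has already been folded into the scale (which is what "bake 1+ into modulation scale positions" refers to); the names and shapes are illustrative, not the repo's.

import mlx.core as mx

EPS = 1e-6

def modulate_reference(x, scale, shift):
    # Old formulation: parameter-free layer norm, then x_norm * (1 + scale) + shift.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_norm = (x - mean) / mx.sqrt(var + EPS)
    return x_norm * (1 + scale) + shift

def modulate_fused(x, scale_plus_one, shift):
    # New formulation: the modulation scale/shift ride in as layer_norm's
    # weight/bias, so the norm and the affine modulation run as one fused op.
    return mx.fast.layer_norm(x, scale_plus_one, shift, EPS)

dim = 16
x = mx.random.normal((1, 8, dim))
scale = mx.random.normal((dim,)) * 0.02
shift = mx.random.normal((dim,)) * 0.02

print(mx.allclose(modulate_reference(x, scale, shift),
                  modulate_fused(x, 1 + scale, shift), atol=1e-4))  # True

# Per the diff above, the block stores modulation as a (1, 6, dim) tensor and
# uses slots 1 and 4 as the layer_norm weights, so a load-time sanitize step
# can fold the "1 +" in once. Hypothetical sketch of that baking step:
modulation = mx.random.normal((1, 6, dim)) * 0.02
bake = mx.array([0.0, 1.0, 0.0, 0.0, 1.0, 0.0]).reshape(1, 6, 1)
modulation = modulation + bake  # slots 1 and 4 now already include the 1+

Doing the fold once at weight-loading time keeps the per-step path to a single fused kernel per modulated norm, which is why the compiled _modulate and _layer_norm helpers could be dropped.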
